TestEmbeddingStoreApp.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # @description:
  2. # @author: licanglong
  3. # @date: 2025/12/19 16:32
  4. from app.App import App
  5. from app.utils.pathutils import getpath
  6. class EmbeddingStoreApp(App):
  7. def run(self, *args, **kwargs):
  8. import uuid
  9. from typing import List
  10. from qdrant_client import QdrantClient
  11. from qdrant_client.models import VectorParams, Distance, PointStruct
  12. from sentence_transformers import SentenceTransformer
  13. if not kwargs or not kwargs.get("collection_name"):
  14. raise ValueError("miss collection_name value")
  15. if not kwargs or not kwargs.get("vector_data"):
  16. raise ValueError("miss vector_data value")
  17. collection_name: str = kwargs['collection_name']
  18. vector_size = 1792
  19. vector_data: dict = kwargs['vector_data'] # case_embed rule_embed merchants_embed edges_embed
  20. client = QdrantClient(host="117.72.147.109", port=16333)
  21. model = SentenceTransformer(getpath(r"res\models\acge_text_embedding"))
  22. collections = client.get_collections().collections
  23. exists = any(c.name == collection_name for c in collections)
  24. if not exists:
  25. client.create_collection(
  26. collection_name=collection_name,
  27. vectors_config=VectorParams(
  28. size=vector_size,
  29. distance=Distance.COSINE,
  30. ),
  31. )
  32. points: List[PointStruct] = []
  33. for item in vector_data:
  34. vector = model.encode(item['embedding_text'])
  35. point_id = str(uuid.uuid4())
  36. points.append(
  37. PointStruct(
  38. id=point_id,
  39. vector=vector.tolist(),
  40. payload=item,
  41. )
  42. )
  43. client.upsert(
  44. collection_name=collection_name,
  45. points=points,
  46. )