zinoubm's picture
finishing the main feature
5ddfe7e
raw
history blame
3.47 kB
import os
from qdrant_client import QdrantClient
from qdrant_client.http import models
from dotenv import load_dotenv
from uuid import uuid4
# Let python-dotenv populate the process environment before the settings
# below are read.
load_dotenv()

# Qdrant connection / collection settings taken from the environment.
# Every value is a ``str`` exactly as found in the environment, or ``None``
# when the variable is not set -- no numeric conversion happens here.
COLLECTION_NAME = os.environ.get("COLLECTION_NAME")
COLLECTION_SIZE = os.environ.get("COLLECTION_SIZE")
QDRANT_PORT = os.environ.get("QDRANT_PORT")
QDRANT_HOST = os.environ.get("QDRANT_HOST")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
class QdrantManager:
    """
    A class for managing a single collection in the Qdrant database.

    Args:
        collection_name (str): The name of the collection to manage.
        collection_size (int): The size (dimensionality) of the vectors stored
            in the collection. Numeric strings, as read from environment
            variables, are accepted and converted to ``int``.
        port (int): The port number for the Qdrant API. Numeric strings are
            accepted and converted to ``int``; ``None`` is passed through.
        host (str): The hostname or IP address for the Qdrant server.
        api_key (str): The API key for authenticating with the Qdrant server.
        recreate_collection (bool): Whether to recreate the collection (see
            ``recreate_collection()``) before validating it.

    Attributes:
        client (qdrant_client.QdrantClient): The Qdrant client object for interacting with the API.

    Raises:
        ValueError: If ``collection_size`` or ``port`` is not a valid integer,
            or if the existing collection's vector size differs from
            ``collection_size``.
    """

    def __init__(
        self,
        collection_name=COLLECTION_NAME,
        collection_size: int = COLLECTION_SIZE,
        port: int = QDRANT_PORT,
        host=QDRANT_HOST,
        api_key=QDRANT_API_KEY,
        recreate_collection: bool = False,
    ):
        self.collection_name = collection_name
        # The defaults come from os.getenv and are therefore strings (or
        # None); normalise them once so the rest of the class sees ints.
        self.collection_size = self._coerce_int(collection_size, "collection_size")
        self.host = host
        self.port = self._coerce_int(port, "port")
        self.api_key = api_key
        self.client = QdrantClient(host=host, port=self.port, api_key=api_key)
        self.setup_collection(self.collection_size, recreate_collection)

    @staticmethod
    def _coerce_int(value, name="value"):
        """Return ``value`` as an ``int``; ``None`` or a blank string gives ``None``."""
        if value is None:
            return None
        if isinstance(value, str) and not value.strip():
            return None
        try:
            return int(value)
        except (TypeError, ValueError) as err:
            raise ValueError(f"{name} must be an integer, got {value!r}") from err

    def setup_collection(self, collection_size: int, recreate_collection: bool):
        """
        Optionally recreate the collection, then validate its vector size.

        Args:
            collection_size (int): Expected vector size. When ``None`` the
                size check is skipped because there is nothing to compare.
            recreate_collection (bool): Recreate the collection first.

        Raises:
            ValueError: If the existing collection's vector size differs from
                ``collection_size``.
        """
        if recreate_collection:
            self.recreate_collection()

        expected_size = self._coerce_int(collection_size, "collection_size")

        # Only the lookup is best-effort: a failure here (for example the
        # collection cannot be fetched) is reported and setup continues, as
        # before. The size check is kept outside the ``try`` so that the
        # mismatch error below actually reaches the caller.
        try:
            collection_info = self.get_collection_info()
        except Exception as e:
            print(e)
            return

        if expected_size is None:
            return

        if collection_info["vector_size"] != expected_size:
            raise ValueError(
                "\n"
                f"Existing collection {self.collection_name} has different collection size\n"
                "To use the new collection configuration, you need to recreate the collection as it already exists with a different configuration.\n"
                "use recreate_collection = True.\n"
            )

    def recreate_collection(self):
        """
        Recreate the managed collection with cosine distance.

        Raises:
            ValueError: If no vector size is configured on this instance.
        """
        vector_size = self._coerce_int(self.collection_size, "collection_size")
        if vector_size is None:
            raise ValueError(
                "collection_size must be set in order to recreate the collection"
            )
        self.client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=vector_size, distance=models.Distance.COSINE
            ),
        )

    def get_collection_info(self):
        """
        Fetch summary information about the managed collection.

        Returns:
            dict: ``points_count``, ``vectors_count``, ``indexed_vectors_count``
            and ``vector_size``. A counter that comes back as ``None`` is
            returned as ``None`` instead of raising ``TypeError``.
        """
        collection_info = self.client.get_collection(
            collection_name=self.collection_name
        )
        # NOTE(review): ``.vectors.size`` assumes a single unnamed vector
        # configuration -- confirm if named vectors are ever used.
        return {
            "points_count": self._coerce_int(
                collection_info.points_count, "points_count"
            ),
            "vectors_count": self._coerce_int(
                collection_info.vectors_count, "vectors_count"
            ),
            "indexed_vectors_count": self._coerce_int(
                collection_info.indexed_vectors_count, "indexed_vectors_count"
            ),
            "vector_size": int(collection_info.config.params.vectors.size),
        }

    def search_point(self, query_vector, limit=5):
        """
        Search the managed collection for the points closest to a vector.

        Args:
            query_vector: The vector to search with.
            limit (int): Maximum number of results to request.

        Returns:
            The response of ``client.search`` unchanged.
        """
        return self.client.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            limit=limit,
        )
# Shared module-level instance built entirely from the environment-derived
# defaults. Note that this runs at import time: constructing the manager
# creates a QdrantClient and runs the collection setup immediately.
qdrant_manager = QdrantManager()