import math
import shutil
from pathlib import Path
from typing import Iterator, List, Optional, Sequence, Tuple

import tqdm
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from loguru import logger

from app.config.models.configs import Config
from app.parsers.splitter import Document
from app.utils import torch_device


class ChromaDenseVectorDB:
    """Dense vector store backed by a Chroma collection persisted on disk.

    Embeddings are computed with the SentenceTransformer model named in
    ``config.embeddings.embedding_model.model_name``. The Chroma handle and
    the retriever built on top of it are created lazily on first access and
    cached on the instance.
    """

    def __init__(self, persist_folder: str, config: Config):
        """Store the configuration and build the embedding function.

        Args:
            persist_folder: Directory Chroma persists the collection in.
            config: Application configuration; the embedding model name and
                ``semantic_search.max_k`` are read from it.
        """
        self._persist_folder = persist_folder
        self._config = config

        logger.info(f"Embedding model config: {config}")
        self._embeddings = SentenceTransformerEmbeddings(
            model_name=config.embeddings.embedding_model.model_name,
            model_kwargs={"device": torch_device()},
        )
        # Number of documents sent to Chroma per call in generate_embeddings.
        self.batch_size = 200

        # Lazily created by the properties below.
        self._retriever = None
        self._vectordb = None

    @property
    def retriever(self):
        """Retriever over ``vectordb``, created on first access and cached."""
        if self._retriever is None:
            self._retriever = self._load_retriever()
        return self._retriever

    @property
    def vectordb(self):
        """Chroma store opened on the persist folder, created on first access."""
        if self._vectordb is None:
            self._vectordb = Chroma(
                persist_directory=self._persist_folder,
                embedding_function=self._embeddings,
            )
        return self._vectordb

    def generate_embeddings(
        self,
        docs: List[Document],
        clear_persist_folder: bool = True,
    ):
        """Embed ``docs`` in batches and persist them in the Chroma folder.

        Args:
            docs: Documents to embed. Each one must carry a ``document_id``
                entry in its metadata; it is used as the Chroma record id.
            clear_persist_folder: When true, delete the persist folder (if it
                exists and is a directory) before writing.
        """
        if clear_persist_folder:
            pf = Path(self._persist_folder)
            if pf.exists() and pf.is_dir():
                logger.warning(f"Deleting the content of: {pf}")
                shutil.rmtree(pf)

        logger.info("Generating and persisting the embeddings..")

        vectordb = None
        # Round up: a trailing partial batch is still one iteration, so
        # truncating would make the progress bar overshoot its total.
        total_batches = math.ceil(len(docs) / self.batch_size)
        for group in tqdm.tqdm(
            chunker(docs, size=self.batch_size),
            total=total_batches,
        ):
            ids = [d.metadata["document_id"] for d in group]
            if vectordb is None:
                # The first batch creates the persisted collection.
                vectordb = Chroma.from_documents(
                    documents=group,
                    embedding=self._embeddings,
                    ids=ids,
                    persist_directory=self._persist_folder,
                )
            else:
                # Later batches are appended to the same collection.
                vectordb.add_texts(
                    texts=[doc.page_content for doc in group],
                    embedding=self._embeddings,
                    ids=ids,
                    metadatas=[doc.metadata for doc in group],
                )
        logger.info("Generated embeddings. Persisting...")
        # vectordb stays None when docs is empty; there is nothing to persist.
        if vectordb is not None:
            vectordb.persist()

    def _load_retriever(self, **kwargs):
        """Build a retriever from ``vectordb``, forwarding ``kwargs`` to it."""
        return self.vectordb.as_retriever(**kwargs)

    def get_documents_by_id(self, document_ids: List[str]) -> List[Document]:
        """Fetch stored documents by their ids.

        Args:
            document_ids: Ids of the records to fetch.

        Returns:
            One ``Document`` per record returned by the store, built from its
            text and metadata. An empty ``document_ids`` yields an empty list.
        """
        # Asking for no ids has an unambiguous answer; returning early avoids
        # depending on how the store treats an empty ``ids`` argument.
        if not document_ids:
            return []

        results = self.retriever.vectorstore.get(  # type: ignore
            ids=document_ids, include=["metadatas", "documents"]
        )
        return [
            Document(page_content=text, metadata=metadata)
            for text, metadata in zip(results["documents"], results["metadatas"])
        ]

    def similarity_search_with_relevance_scores(
        self, query: str, filter: Optional[dict] = None
    ) -> List[Tuple[Document, float]]:
        """Run a similarity search and return documents with their scores.

        Args:
            query: Text to search for.
            filter: Optional metadata filter. A dict with more than one key is
                rewritten as an ``$and`` of ``$eq`` clauses, one per key; any
                other value is passed to the store unchanged.

        Returns:
            ``(document, relevance_score)`` pairs, at most
            ``config.semantic_search.max_k`` of them.
        """
        where = filter
        if isinstance(where, dict) and len(where) > 1:
            where = {"$and": [{key: {"$eq": value}} for key, value in where.items()]}
            logger.debug(f"Filter = {where}")

        return self.retriever.vectorstore.similarity_search_with_relevance_scores(
            query, k=self._config.semantic_search.max_k, filter=where
        )


def chunker(seq: Sequence, size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of ``seq`` holding at most ``size`` items.

    The last slice is shorter when ``len(seq)`` is not a multiple of ``size``.
    ``seq`` must support ``len()`` and slicing.
    """
    return (seq[pos: pos + size] for pos in range(0, len(seq), size))