import json
from typing import Literal, Optional, Union

import lancedb
import pyarrow as pa
from lancedb.pydantic import LanceModel

# Arrow schema for the QnA table; the vector column is fixed at 1536 dimensions
# to match the embeddings the data was generated with.
qna_schema = pa.schema(
    [
        pa.field("uid", pa.string()),
        pa.field("vector", pa.list_(pa.float32(), 1536)),
        pa.field("question", pa.string()),
        pa.field("answer", pa.string()),
        pa.field("language", pa.string()),
        pa.field("source", pa.string()),
        pa.field("category", pa.string()),
    ]
)


class QnA(LanceModel):
    """Model returned by searches; `score` carries the search distance."""

    uid: str
    question: str
    answer: str
    language: str
    category: str
    source: str
    score: Optional[float] = None


class LanceVectorDb:
    """Thin wrapper around a LanceDB table holding QnA entries."""

    def __init__(self, path: str):
        self.db = lancedb.connect(path)
        if "qna_table" not in self.db.table_names():
            self.table = self.db.create_table("qna_table", schema=qna_schema)
        else:
            self.table = self.db.open_table("qna_table")

    def init_from_qna_json(self, path: str):
        """Load QnA pairs and their precomputed embeddings from a JSON file."""
        with open(path, encoding="utf-8") as f:
            qna_data = json.load(f)

        qnas = qna_data["qna"]
        embeddings = qna_data["embeddings"]

        # Attach each embedding to its QnA entry; skip entries without one.
        qnas_with_embeddings = []
        for qna in qnas:
            uid = qna["uid"]
            emb = embeddings.get(uid)
            if emb is None:
                continue
            qna["vector"] = emb
            qnas_with_embeddings.append(qna)

        self.insert(qnas_with_embeddings)

    def insert(self, data: Union[dict, list[dict]]):
        """Insert one or many QnA records (each must include a 'vector' field)."""
        if not isinstance(data, list):
            data = [data]

        # Workaround: build a pyarrow Table column by column before adding it.
        # This step is temporary; the LanceDB team is working on fixing this.
        columns = list(data[0].keys())
        data_columns = {column: [d[column] for d in data] for column in columns}
        elements_to_insert = pa.Table.from_pydict(data_columns, schema=qna_schema)

        self.table.add(elements_to_insert)

    def get_qna(
        self,
        vector: list,
        metric: Literal["L2", "cosine"] = "L2",
        filters: Optional[dict] = None,
        limit: int = 3,
    ):
        """Run a vector search and return the closest QnA entries."""
        # Avoid a mutable default argument; fall back to German entries,
        # matching the original default filter.
        if filters is None:
            filters = {"language": "de"}

        filters_string = " AND ".join(f"{k} == '{v}'" for k, v in filters.items())
        print(filters_string)

        query = self.table.search(vector)
        if filters_string:
            query = query.where(filters_string)

        results = query.metric(metric).limit(limit).to_df().to_dict(orient="records")
        # LanceDB reports the distance under "_distance"; expose it as `score`.
        results = [QnA(**result, score=result["_distance"]) for result in results]
        return results
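

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example of how LanceVectorDb might be wired together. The database
# directory, the "qna.json" file (expected shape: {"qna": [...], "embeddings":
# {uid: vector}}), and the zero query vector are hypothetical placeholders; in
# practice the query vector must come from the same 1536-dimensional embedding
# model that produced the stored vectors.
if __name__ == "__main__":
    db = LanceVectorDb("./lancedb_data")   # directory where LanceDB keeps its tables
    db.init_from_qna_json("qna.json")      # bulk-load QnA pairs with precomputed embeddings

    query_vector = [0.0] * 1536            # placeholder; use a real query embedding here
    hits = db.get_qna(
        query_vector,
        metric="cosine",
        filters={"language": "de"},
        limit=3,
    )
    for hit in hits:
        print(f"{hit.score:.4f}  {hit.question} -> {hit.answer}")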