# phone-bot-demo / vector_db.py
# Commit 42b9715 ("ENH: Add new version"), author: neke-leo
import json
import logging
from typing import Literal, Optional, Union

import lancedb
import pyarrow as pa
from lancedb.pydantic import LanceModel
# Arrow schema of the QnA table: one row per question/answer pair plus its
# embedding. "vector" is a fixed-size list of 1536 float32 values
# (presumably the embedding model's output dimension -- confirm against
# whatever produces the embeddings).
qna_schema = pa.schema(
    [
        pa.field(column_name, column_type)
        for column_name, column_type in (
            ("uid", pa.string()),
            ("vector", pa.list_(pa.float32(), 1536)),
            ("question", pa.string()),
            ("answer", pa.string()),
            ("language", pa.string()),
            ("source", pa.string()),
            ("category", pa.string()),
        )
    ]
)
class QnA(LanceModel):
    """A single question/answer record as returned from a search.

    The string fields mirror the text columns of ``qna_schema``; ``score``
    is not stored in the table and defaults to ``None`` until a caller
    supplies it.
    """

    # Identifier of the record.
    uid: str
    # Question text and its answer.
    question: str
    answer: str
    # Metadata columns stored alongside the text.
    language: str
    category: str
    source: str
    # Optional ranking value attached by the caller; absent by default.
    score: Optional[float] = None
class LanceVectorDb:
    """Wrapper around a LanceDB table holding QnA records and embeddings.

    On construction the database at ``path`` is connected and the table
    ``qna_table`` is opened, or created with ``qna_schema`` if it does not
    exist yet.
    """

    def __init__(self, path):
        """Connect to the LanceDB database at ``path`` and bind ``qna_table``."""
        table_name = "qna_table"
        self.db = lancedb.connect(path)
        if table_name in self.db.table_names():
            self.table = self.db.open_table(table_name)
        else:
            self.table = self.db.create_table(table_name, schema=qna_schema)

    def init_from_qna_json(self, path):
        """Load QnA records from a JSON file and insert them into the table.

        The file must contain a ``"qna"`` list of record dicts (each with a
        ``"uid"``) and an ``"embeddings"`` mapping from uid to vector.
        Records without an embedding are skipped.

        Args:
            path: Path of the UTF-8 encoded JSON file.
        """
        with open(path, encoding="utf-8") as f:
            qna_data = json.load(f)

        embeddings = qna_data["embeddings"]
        rows = []
        for qna in qna_data["qna"]:
            emb = embeddings.get(qna["uid"])
            if emb is None:
                continue
            # Build a new dict instead of mutating the loaded record.
            rows.append({**qna, "vector": emb})
        self.insert(rows)

    def insert(self, data: Union[dict, list[dict]]):
        """Add one record or a list of records to the table.

        Args:
            data: A single record dict or a list of record dicts. All
                records must share the keys of the first one. An empty
                list is a no-op.
        """
        if not isinstance(data, list):
            data = [data]
        if not data:
            # Nothing to insert; avoids indexing into an empty list below.
            return

        # Workaround carried over from the original code ("temporary",
        # pending an upstream fix): convert the row dicts into a columnar
        # Arrow table before handing them to LanceDB.
        columns = list(data[0])
        data_columns = {column: [d[column] for d in data] for column in columns}
        elements_to_insert = pa.Table.from_pydict(data_columns, schema=qna_schema)
        self.table.add(elements_to_insert)

    @staticmethod
    def _build_filter(filters: Optional[dict]) -> str:
        """Render ``filters`` as an ``AND``-joined SQL equality predicate.

        Values are stringified and single quotes are doubled so that a value
        containing ``'`` cannot terminate the SQL string literal. Returns an
        empty string when there is nothing to filter on.
        """
        if not filters:
            return ""
        clauses = []
        for column, value in filters.items():
            escaped = str(value).replace("'", "''")
            clauses.append(f"{column} = '{escaped}'")
        return " AND ".join(clauses)

    def get_qna(
        self,
        vector: list,
        metric: Literal["L2", "cosine"] = "L2",
        filters: Optional[dict] = None,
        limit=3,
    ) -> "list[QnA]":
        """Return the records nearest to ``vector``.

        Args:
            vector: Query embedding.
            metric: Distance metric passed to the search.
            filters: Column/value pairs that must all match. When omitted
                (or ``None``) the default ``{"language": "de"}`` is used;
                pass ``{}`` to search without any filter.
            limit: Maximum number of results.

        Returns:
            A list of ``QnA`` objects whose ``score`` is the ``_distance``
            reported by the search.
        """
        if filters is None:
            filters = {"language": "de"}
        filters_string = self._build_filter(filters)
        logging.getLogger(__name__).debug("QnA search filter: %s", filters_string)

        query = self.table.search(vector)
        if filters_string:
            query = query.where(filters_string)
        records = query.metric(metric).limit(limit).to_df().to_dict(orient="records")
        return [QnA(**record, score=record["_distance"]) for record in records]