Spaces:
Runtime error
Runtime error
import json | |
from typing import Literal, Optional, Union | |
import lancedb | |
import pyarrow as pa | |
from lancedb.pydantic import LanceModel | |
# Arrow schema of the "qna_table" LanceDB table: a string uid, a fixed-size
# list of 1536 float32 values holding the embedding, and the Q&A text and
# metadata columns (all strings).
qna_schema = pa.schema(
    [
        ("uid", pa.string()),
        ("vector", pa.list_(pa.float32(), 1536)),
        ("question", pa.string()),
        ("answer", pa.string()),
        ("language", pa.string()),
        ("source", pa.string()),
        ("category", pa.string()),
    ]
)
class QnA(LanceModel):
    """One question/answer record as returned from a vector search.

    The embedding vector is not part of this model; ``score`` is filled in
    by the caller and stays ``None`` when no value is supplied.
    """

    # Identifier of the record.
    uid: str

    # Question and answer text.
    question: str
    answer: str

    # Metadata columns stored alongside the text.
    language: str
    category: str
    source: str

    # Optional search score attached to a result.
    score: Optional[float] = None
class LanceVectorDb:
    """Thin wrapper around a LanceDB table of embedded question/answer pairs.

    All records live in a single table named ``qna_table`` whose layout is
    given by the module-level ``qna_schema``.
    """

    def __init__(self, path):
        """Connect to the LanceDB database at *path* and open ``qna_table``.

        The table is created with ``qna_schema`` when the database does not
        list it yet; otherwise the existing table is opened.
        """
        self.db = lancedb.connect(path)
        if "qna_table" not in self.db.table_names():
            self.table = self.db.create_table("qna_table", schema=qna_schema)
        else:
            self.table = self.db.open_table("qna_table")

    def init_from_qna_json(self, path) -> None:
        """Load Q&A records with embeddings from a JSON file and insert them.

        The file must contain an object with a ``"qna"`` list of record
        dicts (each carrying a ``"uid"`` key) and an ``"embeddings"`` mapping
        from uid to vector. Records without an embedding are skipped.

        Args:
            path: Location of the UTF-8 encoded JSON file.
        """
        with open(path, encoding="utf-8") as f:
            qna_data = json.load(f)
        embeddings = qna_data["embeddings"]

        qnas_with_embeddings = []
        for qna in qna_data["qna"]:
            emb = embeddings.get(qna["uid"])
            if emb is None:
                continue
            qna["vector"] = emb
            qnas_with_embeddings.append(qna)

        # May be empty when no record has an embedding; insert() handles that.
        self.insert(qnas_with_embeddings)

    def insert(self, data: Union[dict, list[dict]]) -> None:
        """Append one record or a list of records to the table.

        Args:
            data: A single record dict or a list of them. The column names
                are taken from the first record, so every record must
                provide those keys. An empty list is a no-op.
        """
        if not isinstance(data, list):
            data = [data]
        if not data:
            # Nothing to add; also avoids indexing data[0] on an empty list.
            return

        # This step is temporary. They are working on fixing this.
        # (Rows are pivoted into columns and wrapped in an Arrow table before
        # being handed to the LanceDB table.)
        columns = list(data[0])
        data_columns = {column: [d[column] for d in data] for column in columns}
        elements_to_insert = pa.Table.from_pydict(data_columns, schema=qna_schema)
        self.table.add(elements_to_insert)

    def get_qna(
        self,
        vector: list,
        metric: Literal["L2", "cosine"] = "L2",
        filters: Optional[dict] = None,
        limit=3,
    ) -> "list[QnA]":
        """Return the records nearest to *vector* as ``QnA`` objects.

        Args:
            vector: Query embedding passed to the table search.
            metric: Distance metric used for the search.
            filters: Column/value pairs that are AND-ed into an equality
                filter. ``None`` (the default) means ``{"language": "de"}``;
                pass an empty dict to search without any filter.
            limit: Maximum number of results.

        Returns:
            A list of ``QnA`` objects whose ``score`` is the ``_distance``
            value reported for each hit.
        """
        if filters is None:
            # Built per call so no mutable default is shared between calls.
            filters = {"language": "de"}

        clauses = []
        for column, value in filters.items():
            # Double embedded single quotes so the value stays one SQL
            # string literal instead of terminating it early.
            escaped = str(value).replace("'", "''")
            clauses.append(f"{column} == '{escaped}'")
        filters_string = " AND ".join(clauses)
        print(filters_string)

        query = self.table.search(vector)
        if filters_string:
            query = query.where(filters_string)
        results = query.metric(metric).limit(limit).to_df().to_dict(orient="records")
        return [QnA(**result, score=result["_distance"]) for result in results]