Spaces:
Runtime error
Runtime error
File size: 2,475 Bytes
a03292c 42b9715 a03292c 42b9715 a03292c 42b9715 a03292c 42b9715 a03292c 42b9715 a03292c 42b9715 a03292c 42b9715 a03292c 42b9715 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import json
from typing import Literal, Optional, Union
import lancedb
import pyarrow as pa
from lancedb.pydantic import LanceModel
# Arrow schema for the Q&A table: a string id, a fixed-size list of 1536
# float32 values as the embedding vector, and string metadata columns.
# The field order below is the column order of the schema.
qna_schema = pa.schema(
    [
        pa.field(column_name, column_type)
        for column_name, column_type in (
            ("uid", pa.string()),
            ("vector", pa.list_(pa.float32(), 1536)),
            ("question", pa.string()),
            ("answer", pa.string()),
            ("language", pa.string()),
            ("source", pa.string()),
            ("category", pa.string()),
        )
    ]
)
class QnA(LanceModel):
    """One question/answer record, optionally carrying a search score.

    ``score`` defaults to ``None``; ``LanceVectorDb.get_qna`` fills it with
    the ``_distance`` value of the corresponding search result row.
    """

    # Identifier of the record.
    uid: str
    # Question and answer texts.
    question: str
    answer: str
    # String metadata stored alongside the texts.
    language: str
    category: str
    source: str
    # Search score; absent unless set by the caller.
    score: Union[float, None] = None
class LanceVectorDb:
    """Wrapper around a LanceDB table named ``qna_table`` that stores Q&A rows.

    The table is created with ``qna_schema`` on first use and reopened on
    subsequent connections to the same database path.
    """

    def __init__(self, path):
        """Connect to the database at *path* and open or create ``qna_table``."""
        self.db = lancedb.connect(path)
        if "qna_table" not in self.db.table_names():
            self.table = self.db.create_table("qna_table", schema=qna_schema)
        else:
            self.table = self.db.open_table("qna_table")

    def init_from_qna_json(self, path):
        """Load Q&A rows and their embeddings from a JSON file and insert them.

        The file must contain a ``"qna"`` list of row dicts (each with a
        ``"uid"`` key) and an ``"embeddings"`` mapping from uid to vector.
        Rows whose uid has no embedding are skipped.

        Args:
            path: Path of the UTF-8 encoded JSON file.
        """
        with open(path, encoding="utf-8") as f:
            qna_data = json.load(f)
        embeddings = qna_data["embeddings"]

        qnas_with_embeddings = []
        for qna in qna_data["qna"]:
            emb = embeddings.get(qna["uid"])
            if emb is None:
                continue
            # Build a new dict rather than mutating the loaded row in place.
            qnas_with_embeddings.append({**qna, "vector": emb})
        self.insert(qnas_with_embeddings)

    def insert(self, data: Union[dict, list[dict]]):
        """Add one row dict or a list of row dicts to the table.

        All rows are expected to share the keys of the first row. An empty
        list is a no-op.

        Args:
            data: A single row dict or a list of row dicts.
        """
        if not isinstance(data, list):
            data = [data]
        if not data:
            # Nothing to insert; indexing data[0] below would raise IndexError.
            return
        # NOTE: This row-to-column conversion is a temporary workaround
        # (the original author notes a fix is being worked on).
        columns = list(data[0])
        data_columns = {column: [d[column] for d in data] for column in columns}
        elements_to_insert = pa.Table.from_pydict(data_columns, schema=qna_schema)
        self.table.add(elements_to_insert)

    @staticmethod
    def _build_filter(filters: Optional[dict]) -> str:
        """Return an ``AND``-joined equality predicate for *filters*.

        ``None`` or an empty dict yields an empty string (no filtering).
        Single quotes inside values are doubled so a value cannot terminate
        the quoted literal early.
        """
        if not filters:
            return ""
        clauses = []
        for key, value in filters.items():
            escaped = str(value).replace("'", "''")
            clauses.append(f"{key} == '{escaped}'")
        return " AND ".join(clauses)

    def get_qna(
        self,
        vector: list,
        metric: Literal["L2", "cosine"] = "L2",
        # The default dict is only read, never mutated, so sharing it
        # across calls is safe; it is kept for backward compatibility.
        filters: Optional[dict] = {"language": "de"},
        limit=3,
    ):
        """Search the table for the rows nearest to *vector*.

        Args:
            vector: Query embedding passed to the table search.
            metric: Distance metric name given to the query.
            filters: Column/value pairs combined with ``AND`` as equality
                conditions; ``None`` or an empty dict disables filtering.
            limit: Maximum number of results requested.

        Returns:
            A list of ``QnA`` objects whose ``score`` is the ``_distance``
            value of the corresponding result row.
        """
        filters_string = self._build_filter(filters)
        print(filters_string)
        query = self.table.search(vector)
        if filters_string:
            query = query.where(filters_string)
        results = query.metric(metric).limit(limit).to_df().to_dict(orient="records")
        return [QnA(**result, score=result["_distance"]) for result in results]
|