File size: 2,475 Bytes
a03292c
 
 
 
 
 
 
 
 
 
 
 
 
 
42b9715
a03292c
 
 
 
 
 
 
 
 
 
 
42b9715
a03292c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42b9715
a03292c
 
42b9715
 
a03292c
42b9715
a03292c
42b9715
 
a03292c
42b9715
a03292c
42b9715
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
from typing import Literal, Optional, Union

import lancedb
import pyarrow as pa
from lancedb.pydantic import LanceModel

# Arrow schema used for the "qna_table" Lance table: a fixed-size float32
# vector column (1536 values per row) alongside string-typed metadata columns.
qna_schema = pa.schema(
    [
        pa.field(column_name, column_type)
        for column_name, column_type in (
            ("uid", pa.string()),
            ("vector", pa.list_(pa.float32(), 1536)),
            ("question", pa.string()),
            ("answer", pa.string()),
            ("language", pa.string()),
            ("source", pa.string()),
            ("category", pa.string()),
        )
    ]
)


class QnA(LanceModel):
    """A single question/answer record, as built from a table search result.

    ``LanceVectorDb.get_qna`` constructs instances of this model from each
    result row and fills ``score`` from the row's ``_distance`` value.
    """

    # Record identifier; also the key used to look up the record's embedding
    # when loading from JSON.
    uid: str
    question: str
    answer: str
    # Compared against the "language" filter in queries (e.g. "de").
    language: str
    category: str
    source: str
    # Search distance for this hit; None when the record did not come from a
    # search.
    score: Optional[float] = None


class LanceVectorDb:
    """Wrapper around a LanceDB table (``qna_table``) holding Q&A embeddings.

    The table is created with ``qna_schema`` on first use and re-opened on
    subsequent runs.
    """

    def __init__(self, path):
        """Connect to the database at *path* and open or create ``qna_table``.

        Args:
            path: Location handed directly to ``lancedb.connect``.
        """
        self.db = lancedb.connect(path)

        if "qna_table" not in self.db.table_names():
            self.table = self.db.create_table("qna_table", schema=qna_schema)
        else:
            self.table = self.db.open_table("qna_table")

    def init_from_qna_json(self, path):
        """Load Q&A records plus embeddings from a JSON file and insert them.

        The JSON document must have a ``"qna"`` list of record dicts (each
        carrying a ``"uid"``) and an ``"embeddings"`` mapping from uid to
        vector. Records whose uid has no embedding are skipped.

        Args:
            path: Path of the UTF-8 encoded JSON file.
        """
        with open(path, encoding="utf-8") as f:
            qna_data = json.load(f)

        embeddings = qna_data["embeddings"]

        qnas_with_embeddings = []
        for qna in qna_data["qna"]:
            emb = embeddings.get(qna["uid"])
            if emb is None:
                continue

            # The freshly loaded dict is extended in place with its vector.
            qna["vector"] = emb
            qnas_with_embeddings.append(qna)

        self.insert(qnas_with_embeddings)

    def insert(self, data: Union[dict, list[dict]]):
        """Add one record or a list of records to the table.

        Args:
            data: A single record dict or a list of record dicts. All records
                must provide the keys present in the first record. An empty
                list is a no-op.
        """
        if not isinstance(data, list):
            data = [data]
        if not data:
            # Nothing to add; avoids indexing data[0] on an empty batch.
            return

        # Workaround (marked temporary by the original author, pending an
        # upstream fix): pivot the row dicts into per-column lists and hand
        # the table an explicit Arrow table instead of the raw dicts.
        columns = list(data[0])
        data_columns = {column: [d[column] for d in data] for column in columns}
        elements_to_insert = pa.Table.from_pydict(data_columns, schema=qna_schema)
        self.table.add(elements_to_insert)

    @staticmethod
    def _build_filter(filters: Optional[dict]) -> str:
        """Return an ``AND``-joined equality predicate for *filters*.

        Values are rendered as single-quoted string literals with embedded
        single quotes doubled, so a value cannot terminate the literal early.
        Returns an empty string when *filters* is ``None`` or empty.
        """
        if not filters:
            return ""

        clauses = []
        for column, value in filters.items():
            escaped = str(value).replace("'", "''")
            clauses.append(f"{column} == '{escaped}'")
        return " AND ".join(clauses)

    def get_qna(
        self,
        vector: list,
        metric: Literal["L2", "cosine"] = "L2",
        # NOTE: this default dict is shared between calls; it is only read,
        # never mutated, and is kept as-is for backward compatibility.
        filters: Optional[dict] = {"language": "de"},
        limit=3,
    ) -> "list[QnA]":
        """Search the table for the records nearest to *vector*.

        Args:
            vector: Query embedding passed to ``table.search``.
            metric: Distance metric forwarded to the query.
            filters: Column/value pairs that must all match exactly. ``None``
                or an empty dict disables filtering.
            limit: Maximum number of results requested.

        Returns:
            A list of ``QnA`` objects, each with ``score`` set to the result's
            ``_distance`` value.
        """
        filters_string = self._build_filter(filters)

        query = self.table.search(vector)

        if filters_string:
            query = query.where(filters_string)

        rows = query.metric(metric).limit(limit).to_df().to_dict(orient="records")

        return [QnA(**row, score=row["_distance"]) for row in rows]