File size: 2,169 Bytes
4cf88e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import chromadb
from datetime import datetime

# Single module-level Chroma client; every helper in this module goes through it.
chroma_client = chromadb.Client()


def get_or_create_collection(coll_name: str):
    """
    return the collection called coll_name, creating it when it does not exist yet

    The first six characters of the name are stored as the collection's
    'date' metadata.
    """
    stamp = coll_name[:6]
    return chroma_client.get_or_create_collection(
        name=coll_name,
        metadata={"date": stamp},
    )


def get_collection(coll_name: str):
    """
    look up an existing collection by name and return it
    """
    return chroma_client.get_collection(name=coll_name)


def reset_collection(coll_name: str):
    """
    empty a collection while keeping the collection itself

    :param coll_name: name of an existing collection
    :return: the collection object, with its stored entries removed
    """
    coll = chroma_client.get_collection(name=coll_name)
    # A delete() call without ids or filters is rejected by recent chromadb
    # releases, so the stored ids are listed first and removed explicitly.
    stored_ids = coll.get(include=[])["ids"]
    if stored_ids:
        # Skip the call on an already-empty collection: there is nothing to remove.
        coll.delete(ids=stored_ids)
    return coll


def _collection_hour(stamp, now):
    """
    turn an 'MMDDHH' stamp into a datetime, or None when the stamp is unusable
    """
    text = str(stamp)
    if len(text) != 6:
        return None
    try:
        month, day, hour = int(text[:2]), int(text[2:4]), int(text[4:])
    except ValueError:
        return None
    # The stamp carries no year: use the most recent year that places it in the
    # past, so a December stamp read in January is treated as last year's.
    for year in (now.year, now.year - 1):
        try:
            moment = datetime(year, month, day, hour)
        except ValueError:
            # Not a real calendar date in that year (e.g. Feb 29, month 13).
            continue
        if moment <= now:
            return moment
    return None


def delete_old_collections(old=2):
    """
    delete every collection whose 'date' metadata is more than `old` hours old

    The 'date' metadata is read as an 'MMDDHH' stamp and compared with the
    current local hour using real date arithmetic, so the comparison stays
    correct across midnight and across the new year. Collections without
    metadata, without a 'date' entry, or with a stamp that is not a valid
    'MMDDHH' value are left alone.

    :param old: age threshold in hours; collections strictly older are deleted
    """
    from datetime import timedelta  # local import: the module only imports `datetime`

    now = datetime.now()
    # Compare at hour granularity, as the stamps themselves carry no minutes.
    cutoff = now.replace(minute=0, second=0, microsecond=0) - timedelta(hours=old)

    for coll in chroma_client.list_collections():
        stamp = (coll.metadata or {}).get('date')
        created = _collection_hour(stamp, now)
        if created is not None and created < cutoff:
            chroma_client.delete_collection(coll.name)


def add_texts_to_collection(coll_name: str, texts: [str], file: str, source: str):
    """
    add texts to a collection : texts originate all from the same file

    Each text is stored under the id '<file>-<index>' with the metadata
    {file: 1, 'source': source}. Entries already stored under those ids are
    deleted first, so they end up replaced. A failure while writing is
    reported on stdout instead of being raised.

    :param coll_name: name of an existing collection
    :param texts: the text chunks to store; an empty list is a no-op
    :param file: name of the file the texts come from (used in ids and metadata)
    :param source: value stored under the 'source' metadata key
    """
    if not texts:
        # Nothing to store, so there is no reason to hit the backend with
        # empty id lists.
        return

    coll = chroma_client.get_collection(name=coll_name)
    metadatas = [{file: 1, 'source': source} for _ in texts]
    ids = [f"{file}-{i}" for i in range(len(texts))]
    try:
        coll.delete(ids=ids)
        coll.add(documents=texts, metadatas=metadatas, ids=ids)
    except Exception as err:
        # Best-effort write: report and carry on, but let KeyboardInterrupt /
        # SystemExit through and show what actually went wrong.
        print(f"exception raised for collection :{coll_name}, texts: {texts} from file {file} and source {source}"
              f": {err!r}")


def delete_collection(coll_name: str):
    """
    remove the collection called coll_name from the client
    """
    chroma_client.delete_collection(name=coll_name)


def list_collections():
    """
    return whatever the client reports as its current collections
    """
    known = chroma_client.list_collections()
    return known


def query_collection(coll_name: str, query: str,  from_files: [str], n_results: int = 4):
    """
    query a collection, restricted to the texts that were added from from_files

    :param coll_name: name of an existing collection
    :param query: the query text
    :param from_files: file names to search in; must not be empty
    :param n_results: maximum number of results, capped at the collection size
    :return: the result of the collection query, or "" when the query failed
    :raises AssertionError: when from_files is empty
    """
    # Raised explicitly rather than via `assert`, so the check survives `python -O`.
    if not from_files:
        raise AssertionError("from_files must contain at least one file name")

    coll = chroma_client.get_collection(name=coll_name)
    # Texts are stored with the metadata {file: 1}; match any of the given files.
    file_filters = [{file: 1} for file in from_files]
    where_ = file_filters[0] if len(file_filters) == 1 else {'$or': file_filters}
    # Never ask for more results than the collection holds.
    n_results_ = min(n_results, coll.count())

    ans = ""
    try:
        ans = coll.query(query_texts=query, n_results=n_results_, where=where_)
    except Exception as err:
        # Best-effort query: report and fall back to "", but let
        # KeyboardInterrupt / SystemExit through and show the cause.
        print(f"exception raised at query collection for collection {coll_name} and query {query} from files "
              f"{from_files}: {err!r}")

    return ans