|
import chromadb |
|
from datetime import datetime |
|
|
|
# Module-level Chroma client; every helper function in this file goes through it.
chroma_client = chromadb.Client()
|
|
|
|
|
def get_or_create_collection(coll_name: str):
    """
    Fetch the collection named ``coll_name`` via the client's get-or-create call.

    The first six characters of ``coll_name`` are passed along as the
    ``date`` entry of the collection metadata.
    """
    return chroma_client.get_or_create_collection(
        name=coll_name,
        metadata={"date": coll_name[:6]},
    )
|
|
|
|
|
def get_collection(coll_name: str):
    """Look up the collection called ``coll_name`` on the shared client and return it."""
    return chroma_client.get_collection(name=coll_name)
|
|
|
|
|
def reset_collection(coll_name: str):
    """
    Remove every record stored in the collection ``coll_name`` and return it.

    The collection itself is kept; only its contents are deleted.

    The stored ids are fetched first and handed to ``delete`` explicitly:
    recent Chroma releases reject a ``delete()`` call that carries no
    ``ids``/``where``/``where_document`` argument instead of emptying the
    collection, so the argument-less call is not portable.
    """
    coll = chroma_client.get_collection(name=coll_name)
    # ``include=[]`` skips documents/metadatas; the ids are always part of the result.
    stored_ids = coll.get(include=[])["ids"]
    # Nothing stored means nothing to delete; avoid sending an empty id list.
    if stored_ids:
        coll.delete(ids=stored_ids)
    return coll
|
|
|
|
|
def _stamp_to_datetime(stamp, now):
    """Return the latest datetime matching the ``MMDDHH`` stamp that is not after ``now``, or None."""
    text = str(stamp)
    if len(text) != 6:
        return None
    try:
        month, day, hour = int(text[:2]), int(text[2:4]), int(text[4:6])
    except ValueError:
        return None
    # The stamp has no year: start with the current one and walk back, which
    # handles stamps from last December and 29 February in a past leap year.
    for year in range(now.year, now.year - 5, -1):
        try:
            candidate = datetime(year, month, day, hour)
        except ValueError:
            continue
        if candidate <= now:
            return candidate
    return None


def delete_old_collections(old=2):
    """
    Delete every collection whose hour stamp is more than ``old`` hours old.

    The stamp is read from the ``date`` metadata entry, formatted ``MMDDHH``.
    Ages are computed with real datetime arithmetic at hour granularity, so
    the comparison stays correct across day, month and year boundaries
    (plain integer subtraction on ``MMDDHH`` values is not).

    Collections without metadata, without a ``date`` entry, or with a stamp
    that is not a valid ``MMDDHH`` value are left untouched.

    :param old: maximum age in hours a collection may have before deletion.
    """
    this_hour = datetime.now().replace(minute=0, second=0, microsecond=0)
    for coll in chroma_client.list_collections():
        metadata = coll.metadata or {}
        created = _stamp_to_datetime(metadata.get('date'), this_hour)
        if created is None:
            continue
        age_in_hours = (this_hour - created).total_seconds() / 3600
        if age_in_hours > old:
            chroma_client.delete_collection(coll.name)
|
|
|
|
|
def add_texts_to_collection(coll_name: str, texts: [str], file: str, source: str):
    """
    Add texts to a collection; the texts all originate from the same file.

    Each text is stored under the id ``"<file>-<index>"`` with the metadata
    ``{file: 1, 'source': source}``. Records already stored under those ids
    are deleted first, so adding the same file again replaces them.

    Errors raised by the delete/add step are reported on stdout together with
    their cause and are not re-raised (best effort). Errors from looking up
    the collection propagate to the caller.

    :param coll_name: name of an existing collection.
    :param texts: the text chunks to store.
    :param file: name of the file the texts come from; used as metadata key
        and id prefix.
    :param source: stored under the ``'source'`` metadata key.
    """
    coll = chroma_client.get_collection(name=coll_name)
    if not texts:
        # Nothing to store, so skip the round trip to Chroma.
        return
    # NOTE(review): only ids 0..len(texts)-1 are replaced; if the file had more
    # chunks on a previous call, the extra ones stay in the collection.
    metadatas = [{file: 1, 'source': source} for _ in texts]
    ids = [file + '-' + str(i) for i in range(len(texts))]
    try:
        coll.delete(ids=ids)
        coll.add(documents=texts, metadatas=metadatas, ids=ids)
    except Exception as exc:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
        # SystemExit still propagate; the cause is included in the report.
        print(f"exception raised for collection :{coll_name}, texts: {texts} from file {file} and source {source}"
              f": {exc!r}")
|
|
|
|
|
def delete_collection(coll_name: str):
    """Ask the shared client to delete the collection named ``coll_name``."""
    chroma_client.delete_collection(name=coll_name)
|
|
|
|
|
def list_collections():
    """Return the shared client's listing of its collections, unmodified."""
    existing = chroma_client.list_collections()
    return existing
|
|
|
|
|
def query_collection(coll_name: str, query: str, from_files: [str], n_results: int = 4):
    """
    Query the collection ``coll_name`` with ``query``, restricted to given files.

    A record qualifies when its metadata contains ``{file: 1}`` for at least
    one name in ``from_files`` (several names are combined with ``$or``).
    The number of requested results is ``n_results`` capped at the number of
    records in the collection.

    :param coll_name: name of an existing collection.
    :param query: the query text.
    :param from_files: non-empty list of file names to search in.
    :param n_results: maximum number of results to request.
    :return: the value returned by the collection's ``query`` call, or an
        empty string when the collection is empty or the query fails (the
        failure and its cause are reported on stdout).
    :raises AssertionError: when ``from_files`` is empty.
    """
    if not from_files:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type stays the same.
        raise AssertionError("from_files must contain at least one file name")

    coll = chroma_client.get_collection(name=coll_name)

    file_filters = [{file: 1} for file in from_files]
    where = file_filters[0] if len(file_filters) == 1 else {'$or': file_filters}

    available = coll.count()
    if available == 0:
        # An empty collection would cap the request at zero results, which
        # cannot succeed; return the same value as the failure path.
        return ""
    capped_n_results = min(n_results, available)

    try:
        return coll.query(query_texts=query, n_results=capped_n_results, where=where)
    except Exception as exc:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
        # SystemExit still propagate; the cause is included in the report.
        print(f"exception raised at query collection for collection {coll_name} and query {query} from files "
              f"{from_files}: {exc!r}")
        return ""
|
|