lindsay-qu
committed on
Update core/retriever/chroma_retriever.py
Browse files
core/retriever/chroma_retriever.py
CHANGED
@@ -31,6 +31,7 @@ class ChromaRetriever(BaseRetriever):
|
|
31 |
if not os.path.exists("persist"):
|
32 |
os.mkdir("persist")
|
33 |
client = PersistentClient(path="persist")
|
|
|
34 |
|
35 |
try:
|
36 |
collection = client.get_collection(name=collection_name)
|
@@ -41,8 +42,11 @@ class ChromaRetriever(BaseRetriever):
|
|
41 |
docs = pdf_loader.load()
|
42 |
|
43 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=split_args["size"], chunk_overlap=split_args["overlap"])
|
44 |
-
|
45 |
-
texts = [
|
|
|
|
|
|
|
46 |
|
47 |
collection = client.create_collection(name=collection_name)
|
48 |
if embed_model is not None:
|
@@ -50,12 +54,14 @@ class ChromaRetriever(BaseRetriever):
|
|
50 |
collection.add(
|
51 |
embeddings=embeddings,
|
52 |
documents=texts,
|
53 |
-
ids=[str(i+1) for i in range(len(texts))]
|
|
|
54 |
)
|
55 |
else:
|
56 |
collection.add(
|
57 |
documents=texts,
|
58 |
-
ids=[str(i+1) for i in range(len(texts))]
|
|
|
59 |
)
|
60 |
|
61 |
self.collection = collection
|
@@ -82,4 +88,4 @@ class ChromaRetriever(BaseRetriever):
|
|
82 |
query_texts=[query],
|
83 |
n_results=k,
|
84 |
)
|
85 |
-
return results['documents'][0]
|
|
|
31 |
if not os.path.exists("persist"):
|
32 |
os.mkdir("persist")
|
33 |
client = PersistentClient(path="persist")
|
34 |
+
print(client.list_collections())
|
35 |
|
36 |
try:
|
37 |
collection = client.get_collection(name=collection_name)
|
|
|
42 |
docs = pdf_loader.load()
|
43 |
|
44 |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=split_args["size"], chunk_overlap=split_args["overlap"])
|
45 |
+
split_docs = text_splitter.split_documents(docs)
|
46 |
+
texts = [doc.page_content for doc in split_docs]
|
47 |
+
|
48 |
+
# TODO
|
49 |
+
titles = [doc.metadata["title"] for doc in split_docs]
|
50 |
|
51 |
collection = client.create_collection(name=collection_name)
|
52 |
if embed_model is not None:
|
|
|
54 |
collection.add(
|
55 |
embeddings=embeddings,
|
56 |
documents=texts,
|
57 |
+
ids=[str(i+1) for i in range(len(texts))],
|
58 |
+
metadatas=[{"title": title} for title in titles]
|
59 |
)
|
60 |
else:
|
61 |
collection.add(
|
62 |
documents=texts,
|
63 |
+
ids=[str(i+1) for i in range(len(texts))],
|
64 |
+
metadatas=[{"title": title} for title in titles]
|
65 |
)
|
66 |
|
67 |
self.collection = collection
|
|
|
88 |
query_texts=[query],
|
89 |
n_results=k,
|
90 |
)
|
91 |
+
return results['documents'][0], [result["title"] for result in results['metadatas'][0]]
|