Jagpreet Singh commited on
Commit
ecacae5
2 Parent(s): 1cb46fc c0c01c6

Merge pull request #1 from SinghJagpreet096/b1

Browse files
Files changed (3) hide show
  1. app.py +60 -59
  2. requirements.txt +4 -4
  3. src/utils.py +30 -46
app.py CHANGED
@@ -1,99 +1,100 @@
1
  import os
2
  import logging
3
 
 
 
 
 
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
- from langchain.embeddings.openai import OpenAIEmbeddings
 
 
 
 
6
  import chainlit as cl
7
- from src.utils import get_docSearch, get_source
8
- from src.model import load_chain
9
 
 
 
 
 
 
 
10
 
11
- welcome_message = """ Upload your file here"""
12
 
13
  @cl.on_chat_start
14
  async def start():
15
- await cl.Message(content="you are in ").send()
16
- logging.info(f"app started")
17
  files = None
18
- while files is None:
19
- files = await cl.AskFileMessage(
20
- content=welcome_message,
21
- accept=["text/plain", "application/pdf"],
22
- max_size_mb=10,
23
- timeout=90
24
- ).send()
25
- logging.info("uploader excecuted")
26
- file = files[0]
27
- msg = cl.Message(content=f"Processing {file.name}....")
28
- await msg.send()
29
 
30
- logging.info("processing started")
31
 
32
- docsearch = get_docSearch(file,cl)
33
-
34
- logging.info("document uploaded success")
35
 
36
- chain = load_chain(docsearch)
 
37
 
38
- logging.info(f"Model loaded successfully")
39
 
40
 
41
- ## let the user know when system is ready
42
 
43
- msg.content = f"{file.name} processed. You begin asking questions"
44
- await msg.update()
45
 
46
- logging.info("processing completed")
 
 
 
 
 
47
 
48
- cl.user_session.set("chain", chain)
 
 
 
 
 
 
 
49
 
50
- logging.info("chain saved for active session")
51
 
52
- @cl.on_message
53
- async def main(message):
54
 
 
55
 
56
- chain = cl.user_session.get("chain")
57
 
58
- logging.info(f"retrived chain for QA {type(chain)}")
59
- cb = cl.AsyncLangchainCallbackHandler(
60
- stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
61
- )
62
-
63
- logging.info("define call backs")
64
 
65
 
66
- cb.answer_reached = True
67
- logging.info("answer reached")
68
 
69
- res = await chain.acall(message, callbacks=[cb])
70
- logging.info("define res")
71
 
 
72
 
73
- logging.info("call backs ")
74
 
 
75
 
 
76
 
77
  answer = res["answer"]
78
- sources = res["sources"].strip()
79
-
80
-
81
- ## get doc from user session
82
- docs = cl.user_session.get("docs")
83
- metadatas = [doc.metadata for doc in docs]
84
- all_sources = [m["source"]for m in metadatas]
85
 
86
 
 
 
 
 
 
87
 
88
- source_elements = get_source(sources,all_sources,docs,cl)
89
 
90
- logging.info("getting source")
91
 
92
- if cb.has_streamed_final_answer:
93
- cb.final_stream.elements = source_elements
94
- await cb.final_stream.update()
95
- logging.info("call back triggred")
96
- else:
97
- await cl.Message(content=answer, elements=source_elements).send()
98
- logging.info("post message")
99
 
 
1
  import os
2
  import logging
3
 
4
+ #pip install pypdf
5
+ #export HNSWLIB_NO_NATIVE = 1
6
+
7
+ from langchain.document_loaders import PyPDFDirectoryLoader, TextLoader
8
  from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+
10
+ from langchain.vectorstores import Chroma
11
+ from langchain.memory import ChatMessageHistory, ConversationBufferMemory
12
+ from langchain.chains import ConversationalRetrievalChain
13
+ from langchain.chat_models import ChatOpenAI
14
  import chainlit as cl
 
 
15
 
16
+ from src.utils import get_docsearch, get_source
17
+
18
+ # text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
19
+ # embeddings = OpenAIEmbeddings()
20
+
21
+ welcome_message = """Welcome"""
22
 
 
23
 
24
  @cl.on_chat_start
25
  async def start():
26
+ await cl.Message("test").send()
 
27
  files = None
28
+ files = await cl.AskFileMessage(
29
+ content=welcome_message,
30
+ accept=["text/plain", "application/pdf"],
31
+ ).send()
 
 
 
 
 
 
 
32
 
33
+ logging.info("file uploaded")
34
 
35
+ file = files[0]
 
 
36
 
37
+ msg = cl.Message(content=f"Processing {file.name}")
38
+ await msg.send()
39
 
40
+ logging.info("file processing")
41
 
42
 
43
+ docsearch = await cl.make_async(get_docsearch)(file)
44
 
45
+ message_history = ChatMessageHistory()
 
46
 
47
+ memory = ConversationBufferMemory(
48
+ memory_key="chat_history",
49
+ output_key="answer",
50
+ chat_memory=message_history,
51
+ return_messages=True
52
+ )
53
 
54
+ ## create chain that uses chroma vector store
55
+ chain = ConversationalRetrievalChain.from_llm(
56
+ ChatOpenAI(model_name="gpt-3.5-turbo",temperature=0, streaming=True),
57
+ chain_type="stuff",
58
+ retriever=docsearch.as_retriever(),
59
+ memory=memory,
60
+ return_source_documents=True,
61
+ )
62
 
 
63
 
64
+ msg.content = f"Processing {file.name} completed. Start asking questions!"
65
+ await msg.update()
66
 
67
+ logging.info("file processed success")
68
 
69
+ cl.user_session.set("chain",chain)
70
 
71
+ logging.info("saved chain in currrent session")
 
 
 
 
 
72
 
73
 
74
+ @cl.on_message
75
+ async def main(message: cl.Message):
76
 
77
+ ## get chain
78
+ chain = cl.user_session.get("chain")
79
 
80
+ logging.info("loaded chain")
81
 
82
+ cb = cl.AsyncLangchainCallbackHandler()
83
 
84
+ logging.info("loaded callbacks")
85
 
86
+ res = await chain.acall(message.content, callbacks=[cb])
87
 
88
  answer = res["answer"]
89
+ source_documents = res["source_documents"]
 
 
 
 
 
 
90
 
91
 
92
+
93
+ text_elements = get_source(answer, source_documents)
94
+ await cl.Message(content=answer, elements=text_elements).send()
95
+
96
+
97
 
 
98
 
 
99
 
 
 
 
 
 
 
 
100
 
requirements.txt CHANGED
@@ -1,7 +1,7 @@
1
  langchain
2
- openai
3
- python-dotenv
4
  chainlit
 
5
  chromadb
6
- tiktoken
7
- tokenizers
 
1
  langchain
2
+ chroma
3
+ pypdf
4
  chainlit
5
+ openai
6
  chromadb
7
+ tiktoken
 
src/utils.py CHANGED
@@ -3,20 +3,19 @@ import click
3
  from langchain.document_loaders import TextLoader
4
  from langchain.document_loaders import PyPDFLoader
5
  from langchain.vectorstores import Chroma
 
 
 
6
 
7
 
8
  from src.config import Config
9
  import logging
10
 
11
- from dotenv import load_dotenv
12
-
13
- load_dotenv()
14
-
15
-
16
-
17
 
18
  def process_file(file: AskFileResponse):
19
- import tempfile
20
 
21
  if file.type == "text/plain":
22
  Loader = TextLoader
@@ -27,52 +26,37 @@ def process_file(file: AskFileResponse):
27
  tempfile.write(file.content)
28
  loader = Loader(tempfile.name)
29
  documents = loader.load()
30
- # text_splitter = text_splitter()
31
- docs = Config.text_splitter.split_documents(documents)
32
-
33
  for i, doc in enumerate(docs):
34
  doc.metadata["source"] = f"source_{i}"
35
  return docs
36
 
37
- def get_docSearch(file,cl):
38
  docs = process_file(file)
39
 
40
- logging.info("files loaded ")
 
41
 
42
- ## save data in user session
43
- cl.user_session.set("docs",docs)
44
-
45
- logging.info("docs saved in active session")
46
-
47
- docsearch = Chroma.from_documents(docs, Config.embeddings)
48
-
49
- logging.info(f"embedding completed {type(Config.embeddings)}")
50
-
51
- logging.info(f"type of docsearch {type(docsearch)}")
52
 
 
 
 
53
  return docsearch
54
 
55
- def get_source(sources,all_sources,docs,cl):
56
- answer = []
57
- source_elements = []
58
- if sources:
59
- found_sources = []
60
-
61
- # Add the sources to the message
62
- for source in sources.split(","):
63
- source_name = source.strip().replace(".", "")
64
- # Get the index of the source
65
- try:
66
- index = all_sources.index(source_name)
67
- except ValueError:
68
- continue
69
- text = docs[index].page_content
70
- found_sources.append(source_name)
71
- # Create the text element referenced in the message
72
- source_elements.append(cl.Text(content=text, name=source_name))
73
-
74
- if found_sources:
75
- answer += f"\nSources: {', '.join(found_sources)}"
76
- else:
77
- answer += "\nNo sources found"
78
- return source_elements,answer
 
3
  from langchain.document_loaders import TextLoader
4
  from langchain.document_loaders import PyPDFLoader
5
  from langchain.vectorstores import Chroma
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain.embeddings.openai import OpenAIEmbeddings
8
+ import chainlit as cl
9
 
10
 
11
  from src.config import Config
12
  import logging
13
 
14
+ text_splitter = RecursiveCharacterTextSplitter()
15
+ embeddings = OpenAIEmbeddings()
 
 
 
 
16
 
17
  def process_file(file: AskFileResponse):
18
+ import tempfile
19
 
20
  if file.type == "text/plain":
21
  Loader = TextLoader
 
26
  tempfile.write(file.content)
27
  loader = Loader(tempfile.name)
28
  documents = loader.load()
29
+ docs = text_splitter.split_documents(documents)
 
 
30
  for i, doc in enumerate(docs):
31
  doc.metadata["source"] = f"source_{i}"
32
  return docs
33
 
34
+ def get_docsearch(file: AskFileResponse):
35
  docs = process_file(file)
36
 
37
+ # Save data in the user session
38
+ cl.user_session.set("docs", docs)
39
 
40
+ # Create a unique namespace for the file
 
 
 
 
 
 
 
 
 
41
 
42
+ docsearch = Chroma.from_documents(
43
+ docs, embeddings
44
+ )
45
  return docsearch
46
 
47
+ def get_source(answer,source_documents):
48
+ text_elements = []
49
+ if source_documents:
50
+ for source_idx, source_doc in enumerate(source_documents):
51
+ source_name = f"source_{source_idx}"
52
+
53
+ text_elements.append(
54
+ cl.Text(content=source_doc.page_content, name=source_name)
55
+ )
56
+ source_names = [text_el.name for text_el in text_elements]
57
+
58
+ if source_names:
59
+ answer += f"\nSources: {', '.join(source_names)}"
60
+ else:
61
+ answer += "\nNo source found"
62
+ return text_elements