chahah committed
Commit d3d9ad2
Parent: 0f7043c

Update app.py

Files changed (1)
  app.py (+56, -60)
app.py CHANGED
@@ -17,66 +17,46 @@ from langchain_core.rate_limiters import InMemoryRateLimiter
 from urllib.parse import urljoin
 
 
-def initialize(arxivcode):
-    #loader = ArxivLoader(query=str(arxivcode),)
-    #docs = loader.load()
-    #retriever = ArxivRetriever(
-    #    load_max_docs=2,
-    #    get_full_documents=True,
-    #)
-    #docs = retriever.invoke(str(arxivcode))
-    #for i in range(len(docs)):
-    #    docs[i].metadata['Published'] = str(docs[i].metadata['Published'])
+# LLM model
+rate_limiter = InMemoryRateLimiter(
+    requests_per_second=0.1,    # <-- MistralAI free tier: one request every 10 seconds
+    check_every_n_seconds=0.01, # wake up every 100 ms to check whether a request is allowed
+    max_bucket_size=10,         # controls the maximum burst size
+)
+llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)
 
-    # Load, chunk and index the contents of the blog.
-    url = ['https://arxiv.org/abs/%s' % arxivcode]
-    loader = WebBaseLoader(url)
-    docs = loader.load()
+# Embeddings
+embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
+# embed_model = "nvidia/NV-Embed-v2"
+embeddings = HuggingFaceInstructEmbeddings(model_name=embed_model)
+# embeddings = MistralAIEmbeddings()
 
+def RAG(llm, docs, embeddings):
 
-    # LLM model
-    rate_limiter = InMemoryRateLimiter(
-        requests_per_second=0.1, # <-- MistralAI free. We can only make a request once every second
-        check_every_n_seconds=0.01, # Wake up every 100 ms to check whether allowed to make a request,
-        max_bucket_size=10, # Controls the maximum burst size.
-    )
-    llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)
-
-    # Embeddings
-    embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
-    # embed_model = "nvidia/NV-Embed-v2"
-    embeddings = HuggingFaceInstructEmbeddings(model_name=embed_model)
-    # embeddings = MistralAIEmbeddings()
+    # Split text
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+    splits = text_splitter.split_documents(docs)
 
-    def format_docs(docs):
-        return "\n\n".join(doc.page_content for doc in docs)
-
-    def RAG(llm, docs, embeddings):
-
-        # Split text
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-        splits = text_splitter.split_documents(docs)
-
-        # Create vector store
-        vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
-
-        # Retrieve and generate using the relevant snippets of the documents
-        retriever = vectorstore.as_retriever()
-
-        # Prompt basis example for RAG systems
-        prompt = hub.pull("rlm/rag-prompt")
-
-        # Create the chain
-        rag_chain = (
-            {"context": retriever | format_docs, "question": RunnablePassthrough()}
-            | prompt
-            | llm
-            | StrOutputParser()
-        )
-
-        return rag_chain
+    # Create vector store
+    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
+
+    # Retrieve and generate using the relevant snippets of the documents
+    retriever = vectorstore.as_retriever()
+
+    # Prompt basis example for RAG systems
+    prompt = hub.pull("rlm/rag-prompt")
 
-    return RAG(llm, docs, embeddings)
+    # Create the chain
+    rag_chain = (
+        {"context": retriever | format_docs, "question": RunnablePassthrough()}
+        | prompt
+        | llm
+        | StrOutputParser()
+    )
+    return rag_chain
+
+def format_docs(docs):
+    return "\n\n".join(doc.page_content for doc in docs)
 
 def handle_prompt(message, history, arxivcode, rag_chain):
     try:
@@ -91,14 +71,30 @@ def handle_prompt(message, history, arxivcode, rag_chain):
 
 greetingsmessage = "Hi, I'm your personal arXiv reader. Ask me questions about the arXiv paper above"
 
+
 with gr.Blocks() as demo:
-    arxiv_code = gr.Textbox("", label="arxiv.number")
-    rag_chain = initialize(arxiv_code)
+    arxiv_code = gr.Textbox("", label="arxiv.number")
+
+    #rag_chain = initialize(arxiv_code)
+    loader = ArxivLoader(query=str(arxiv_code))
+    docs = loader.load()
+    #retriever = ArxivRetriever(
+    #    load_max_docs=2,
+    #    get_full_documents=True,
+    #)
+    #docs = retriever.invoke(str(arxiv_code))
+    #for i in range(len(docs)):
+    #    docs[i].metadata['Published'] = str(docs[i].metadata['Published'])
+
+    # Load, chunk and index the contents of the paper.
+    #url = ['https://arxiv.org/abs/%s' % arxiv_code]
+    #loader = WebBaseLoader(url)
+    #docs = loader.load()
+    rag_chain = RAG(llm, docs, embeddings)
 
-    gr.ChatInterface(handle_prompt, type="messages", theme=gr.themes.Soft(),
+    gr.ChatInterface(handle_prompt, type="messages", theme=gr.themes.Soft(),
                      description=greetingsmessage,
                      additional_inputs=[arxiv_code, rag_chain]
                      )
 
-if __name__=='__main__':
-    demo.launch()
+demo.launch()
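
After this commit, model setup and chain construction run at import time: ArxivLoader fetches the paper, and RAG() splits it with RecursiveCharacterTextSplitter, indexes the chunks in Chroma, and composes retriever, the rlm/rag-prompt template, the rate-limited Mistral model, and StrOutputParser into a single LCEL chain. A minimal sketch of driving that chain outside Gradio (not part of the commit; the paper ID "1706.03762" is only an illustrative placeholder, and MISTRAL_API_KEY is assumed to be set):

    # Sketch only: exercise RAG() directly, reusing llm/embeddings from app.py.
    from langchain_community.document_loaders import ArxivLoader

    docs = ArxivLoader(query="1706.03762").load()  # illustrative paper ID
    chain = RAG(llm, docs, embeddings)
    print(chain.invoke("What problem does the paper address?"))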
 
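One caveat in the new wiring: inside with gr.Blocks(), arxiv_code is a gr.Textbox component, not the string a user will later type, so str(arxiv_code) at build time cannot select a specific paper. A hedged sketch of one way to defer loading until an ID is submitted (build_chain and chain_state are illustrative names, not from this commit):

    # Sketch only: rebuild the chain when the user submits an arXiv ID.
    def build_chain(paper_id):
        docs = ArxivLoader(query=str(paper_id)).load()
        return RAG(llm, docs, embeddings)

    with gr.Blocks() as demo:
        arxiv_code = gr.Textbox("", label="arxiv.number")
        chain_state = gr.State(None)  # holds the live chain for handle_prompt
        arxiv_code.submit(build_chain, arxiv_code, chain_state)
        gr.ChatInterface(handle_prompt, type="messages", theme=gr.themes.Soft(),
                         description=greetingsmessage,
                         additional_inputs=[arxiv_code, chain_state])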