chahah commited on
Commit
e7dbb12
1 Parent(s): e556cb5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://python.langchain.com/docs/tutorials/rag/
2
+ import gradio as gr
3
+ from langchain import hub
4
+ from langchain_chroma import Chroma
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.runnables import RunnablePassthrough
7
+ from langchain_mistralai import MistralAIEmbeddings
8
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
9
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
10
+ from langchain_mistralai import ChatMistralAI
11
+ from langchain_community.document_loaders import PyPDFLoader
12
+ import requests
13
+ from pathlib import Path
14
+ from langchain_community.document_loaders import WebBaseLoader
15
+ from langchain_community.retrievers import ArxivRetriever
16
+ import bs4
17
+ from langchain_core.rate_limiters import InMemoryRateLimiter
18
+ from urllib.parse import urljoin
19
+
20
# Throttle outgoing LLM requests. The free MistralAI tier only tolerates
# roughly one request every ten seconds, so cap the sustained rate there
# while still allowing short bursts.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,     # sustained rate permitted by the MistralAI free tier
    check_every_n_seconds=0.01,  # re-check the token bucket every 100 ms
    max_bucket_size=10,          # upper bound on burst size
)
25
+ """
26
+ # get data
27
+ urlsfile = open("urls.txt")
28
+ urls = urlsfile.readlines()
29
+ urls = [url.replace("\n","") for url in urls]
30
+ urlsfile.close()
31
+
32
+ # Load, chunk and index the contents of the blog.
33
+ loader = WebBaseLoader(urls)
34
+ docs = loader.load()
35
+
36
+ # load arxiv papers
37
+ arxivfile = open("arxiv.txt")
38
+ arxivs = arxivfile.readlines()
39
+ arxivs = [arxiv.replace("\n","") for arxiv in arxivs]
40
+ arxivfile.close()
41
+
42
+ retriever = ArxivRetriever(
43
+ load_max_docs=2,
44
+ get_ful_documents=True,
45
+ )
46
+
47
+ for arxiv in arxivs:
48
+ doc = retriever.invoke(arxiv)
49
+ doc[0].metadata['Published'] = str(doc[0].metadata['Published'])
50
+ docs.append(doc[0])
51
+
52
+
53
+ def format_docs(docs):
54
+ return "\n\n".join(doc.page_content for doc in docs)
55
+
56
+ def RAG(llm, docs, embeddings):
57
+
58
+ # Split text
59
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
60
+ splits = text_splitter.split_documents(docs)
61
+
62
+ # Create vector store
63
+ vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
64
+
65
+ # Retrieve and generate using the relevant snippets of the documents
66
+ retriever = vectorstore.as_retriever()
67
+
68
+ # Prompt basis example for RAG systems
69
+ prompt = hub.pull("rlm/rag-prompt")
70
+
71
+ # Create the chain
72
+ rag_chain = (
73
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
74
+ | prompt
75
+ | llm
76
+ | StrOutputParser()
77
+ )
78
+
79
+ return rag_chain
80
+
81
+ # LLM model
82
+ llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)
83
+
84
+ # Embeddings
85
+ embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
86
+ # embed_model = "nvidia/NV-Embed-v2"
87
+ embeddings = HuggingFaceInstructEmbeddings(model_name=embed_model)
88
+ # embeddings = MistralAIEmbeddings()
89
+
90
+ # RAG chain
91
+ rag_chain = RAG(llm, docs, embeddings)
92
+
93
+ def handle_prompt(message, history):
94
+ try:
95
+ # Stream output
96
+ out=""
97
+ for chunk in rag_chain.stream(message):
98
+ out += chunk
99
+ yield out
100
+ except:
101
+ raise gr.Error("Requests rate limit exceeded")
102
+ """
103
+
104
def handle_prompt(message, history):
    """Chat callback for the Gradio ``ChatInterface``.

    Parameters
    ----------
    message : str
        The user's latest chat message (per the greeting, an arXiv number).
    history : list
        Prior chat turns supplied by Gradio (currently unused).

    Returns
    -------
    str
        The text displayed as the bot's reply.
    """
    # Keep the original debug print so server logs still show the input.
    print(message)
    # Bug fix: the original returned None, which leaves the chat reply empty.
    # Echo the input back until the RAG pipeline (disabled above) is wired in.
    return message
106
+
107
+
108
# Greeting displayed above the chat box.
greetingsmessage = "Hi, I'm your personal arXiv reader. Input the arXiv number of the paper:"

# Sample prompts offered to the user. Bug fix: in the original these were
# defined AFTER demo.launch() — launch() blocks while serving, so the list
# was dead code and was never passed to the UI. Define it first and hand it
# to ChatInterface via its `examples` parameter.
example_questions = [
    "Tell me more about SimBIG",
    "How can you constrain neutrino mass with galaxies?",
    "What is the DESI BGS?",
    "What is SEDflow?",
    "What are normalizing flows?"
]

demo = gr.ChatInterface(
    handle_prompt,
    type="messages",
    title="ChangBot",
    theme=gr.themes.Soft(),
    description=greetingsmessage,
    examples=example_questions,  # surface the sample prompts in the UI
)

# launch() blocks until the server stops, so it must come last.
demo.launch()