File size: 3,473 Bytes
e7dbb12
 
 
 
 
 
 
 
 
 
 
 
 
9db9263
e7dbb12
 
 
 
 
d3d9ad2
 
 
 
 
 
 
9db9263
d3d9ad2
 
 
 
 
8f13397
d3d9ad2
8f13397
d3d9ad2
 
 
bde2b54
d3d9ad2
 
 
 
 
 
 
 
bde2b54
d3d9ad2
 
 
 
 
 
 
 
 
 
 
bde2b54
b38713e
e7dbb12
 
 
 
 
 
 
 
 
 
bde2b54
e7dbb12
d3d9ad2
bde2b54
d3d9ad2
 
 
52a781c
d3d9ad2
 
 
 
 
 
 
 
 
 
 
 
 
 
b38713e
d3d9ad2
bde2b54
b38713e
bde2b54
 
d3d9ad2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# https://python.langchain.com/docs/tutorials/rag/
import uuid
from pathlib import Path
from urllib.parse import urljoin

import bs4
import gradio as gr
import requests
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import WebBaseLoader, ArxivLoader
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.runnables import RunnablePassthrough
from langchain_mistralai import ChatMistralAI
from langchain_mistralai import MistralAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter


# LLM model
# Client-side throttling of calls to the Mistral API, tuned for the free MistralAI tier.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,  # at most one request every 10 seconds on average
    check_every_n_seconds=0.01,  # poll every 10 ms to see whether a request may be made
    max_bucket_size=10,  # maximum burst size (tokens accumulated while idle)
)
llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)

# Embeddings (the commented-out lines are alternative embedding models)
embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
# embed_model = "nvidia/NV-Embed-v2"
embeddings = HuggingFaceInstructEmbeddings(model_name=embed_model)
# embeddings = MistralAIEmbeddings()

def RAG(llm, docs, embeddings, collection_name=None):
    """Build a retrieval-augmented question-answering chain over ``docs``.

    The documents are split into overlapping chunks and indexed in a Chroma
    vector store. For each question the most relevant chunks are retrieved,
    joined into the ``context`` of the ``rlm/rag-prompt`` template (pulled from
    the LangChain hub), and answered by ``llm``.

    Args:
        llm: chat model that generates the answer.
        docs: LangChain documents to index.
        embeddings: embedding model used both for indexing and for queries.
        collection_name: Chroma collection to index into. Defaults to a fresh,
            unique name, so chains built by separate calls never share chunks.

    Returns:
        A runnable that takes the question string and produces the answer
        string (usable with ``invoke`` and ``stream``).
    """
    # Split text
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)

    # Without an explicit client or collection name, every call would reuse the
    # default "langchain" collection of Chroma's process-wide in-memory client,
    # so a second set of documents would be added next to the first one's.
    if collection_name is None:
        collection_name = f"rag-{uuid.uuid4().hex}"

    # Create vector store
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        collection_name=collection_name,
    )

    # Retrieve and generate using the relevant snippets of the documents
    retriever = vectorstore.as_retriever()

    # Prompt basis example for RAG systems
    prompt = hub.pull("rlm/rag-prompt")

    # Create the chain: the input question feeds both the retriever (whose
    # documents become the "context" text) and the prompt's "question" slot.
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain

def format_docs(docs):
    """Return the ``page_content`` of every document, separated by blank lines.

    Used to turn the retriever's documents into the prompt's context text.
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)

def handle_prompt(message, history, arxivcode, rag_chain):
    """Stream ``rag_chain``'s answer to ``message`` for a Gradio chat window.

    After each streamed chunk, yields the whole answer received so far, so the
    reply grows in place in the UI.

    Args:
        message: the user's question, passed unchanged to ``rag_chain.stream``.
        history: chat history supplied by Gradio; not used.
        arxivcode: arXiv identifier from the UI; not used here.
        rag_chain: runnable whose ``stream(message)`` yields string chunks.

    Raises:
        gr.Error: if streaming from ``rag_chain`` raises an ``Exception``.
    """
    out = ""
    try:
        for chunk in rag_chain.stream(message):
            out += chunk
            yield out
    except Exception as err:
        # Deliberately not a bare ``except:``. That would also catch
        # GeneratorExit, which is raised at ``yield`` when the consumer closes
        # this generator, and KeyboardInterrupt, turning a normal cancellation
        # into a bogus error.
        # NOTE(review): every failure is still reported as a rate-limit error;
        # the original exception is chained so the real cause shows in logs.
        raise gr.Error("Requests rate limit exceeded") from err


greetingsmessage = "Hi, I'm your personal arXiv reader. Ask me questions about the arXiv paper above"

# RAG chains already built, keyed by the stripped arXiv identifier the user typed.
# Building one downloads the paper, embeds every chunk and pulls the prompt from
# the hub, so it is done once per paper instead of once per chat message.
_rag_chains = {}


def _get_rag_chain(arxivcode):
    """Return the RAG chain for ``arxivcode``, building and caching it on first use.

    Raises:
        gr.Error: if ``arxivcode`` is blank or the arXiv search finds no document.
    """
    key = (arxivcode or "").strip()
    if not key:
        raise gr.Error("Please enter an arXiv number first")
    if key not in _rag_chains:
        # The chat is about a single paper, so only the best match is loaded.
        docs = ArxivLoader(query=key, load_max_docs=1).load()
        if not docs:
            raise gr.Error(f"No arXiv paper found for '{key}'")
        _rag_chains[key] = RAG(llm, docs, embeddings)
    return _rag_chains[key]


def chat(message, history, arxivcode):
    """ChatInterface callback: answer ``message`` about the paper named in the textbox.

    Gradio passes the textbox's current value as ``arxivcode`` on every call, so
    the paper is resolved when the user chats, not when the UI is built.
    """
    rag_chain = _get_rag_chain(arxivcode)
    yield from handle_prompt(message, history, arxivcode, rag_chain)


# The theme is set on the outermost Blocks, which is what styles the whole page.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    arxiv_code = gr.Textbox("", label="arxiv.number")
    gr.ChatInterface(
        chat,
        type="messages",
        description=greetingsmessage,
        # Only Gradio components may be listed here; the RAG chain is built
        # inside ``chat`` from the textbox value.
        additional_inputs=[arxiv_code],
    )

if __name__ == "__main__":
    demo.launch()