# https://python.langchain.com/docs/tutorials/rag/
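"""Gradio Space: chat with an arXiv paper through a LangChain RAG pipeline.

The paper is fetched with ArxivLoader, split into chunks, embedded into a
Chroma vector store, and queried with a rate-limited Mistral chat model.
"""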
import gradio as gr
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import ArxivLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_core.runnables import RunnablePassthrough
from langchain_mistralai import ChatMistralAI, MistralAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# LLM model
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.1,     # MistralAI free tier: at most one request every 10 seconds
    check_every_n_seconds=0.01,  # wake up every 10 ms to check whether a request is allowed
    max_bucket_size=10,          # controls the maximum burst size
)
llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)
# Embeddings
embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
# embed_model = "nvidia/NV-Embed-v2"
# multi-qa-distilbert-cos-v1 is a plain sentence-transformers model, so the
# generic HuggingFaceEmbeddings wrapper applies (the Instruct wrapper targets INSTRUCTOR models)
embeddings = HuggingFaceEmbeddings(model_name=embed_model)
# embeddings = MistralAIEmbeddings()
def RAG(llm, docs, embeddings):
    """Build a retrieval-augmented generation chain over the given documents."""
    # Split text
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(docs)
    # Create vector store
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    # Retrieve and generate using the relevant snippets of the documents
    retriever = vectorstore.as_retriever()
    # Prompt basis example for RAG systems
    prompt = hub.pull("rlm/rag-prompt")
    # Create the chain
    rag_chain = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
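# Sketch of how the returned chain is used (hypothetical question strings;
# assumes `docs` have already been loaded):
#   chain = RAG(llm, docs, embeddings)
#   answer = chain.invoke("What is the main result of the paper?")
#   for partial in chain.stream("Summarize the abstract"):  # streamed variant
#       print(partial, end="")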
def format_docs(docs):
    """Join retrieved chunks into a single context string for the prompt."""
    return "\n\n".join(doc.page_content for doc in docs)
def handle_prompt(message, history, arxivcode, rag_chain):
    """Answer the user's question by streaming tokens from the RAG chain."""
    if rag_chain is None:
        raise gr.Error("Load a paper first: enter its arXiv number above and press Enter")
    try:
        # Stream output incrementally
        out = ""
        for chunk in rag_chain.stream(message):
            out += chunk
            yield out
    except Exception:
        raise gr.Error("Requests rate limit exceeded")
greetingsmessage = "Hi, I'm your personal arXiv reader. Ask me questions about the arXiv paper above"
with gr.Blocks() as demo:
    arxiv_code = gr.Textbox("", label="arXiv number (press Enter to load the paper)")
    # Holds the RAG chain built for the currently loaded paper
    rag_chain = gr.State(None)

    def load_paper(code):
        # Fetch the paper from arXiv and build a RAG chain over its content.
        # (Alternatives kept from earlier experiments: ArxivRetriever with
        # get_full_documents=True, or WebBaseLoader on https://arxiv.org/abs/<code>.)
        docs = ArxivLoader(query=str(code)).load()
        return RAG(llm, docs, embeddings)

    # Build the chain from the submitted value, not from the Textbox component itself
    arxiv_code.submit(load_paper, inputs=arxiv_code, outputs=rag_chain)

    gr.ChatInterface(handle_prompt, type="messages", theme=gr.themes.Soft(),
                     description=greetingsmessage,
                     additional_inputs=[arxiv_code, rag_chain]
                     )

demo.launch()
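# Local run, under assumptions: this file is saved as app.py and MISTRAL_API_KEY
# is set in the environment (required by ChatMistralAI). Rough dependency list:
#   pip install gradio langchain langchain-chroma langchain-community \
#       langchain-mistralai langchain-text-splitters sentence-transformers arxiv pymupdf
#   python app.py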