# chat-with-docs / app.py
import gradio as gr
import torch
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
# Loading the embedding model
encoder = SentenceTransformer('jinaai/jina-embedding-b-en-v1')
print("Embedding model loaded...")
# Load the LLM
# Alternative path (left over from an earlier revision): a quantized GGUF
# model loaded with ctransformers instead of transformers.
# from ctransformers import AutoModelForCausalLM
# llm = AutoModelForCausalLM.from_pretrained(
#     "refuelai/Llama-3-Refueled",
#     model_file="llama-2-7b-chat.Q3_K_S.gguf",
#     model_type="llama",
#     temperature=0.2,
#     repetition_penalty=1.5,
#     max_new_tokens=300,
# )
model_id = "refuelai/Llama-3-Refueled"
tokenizer = AutoTokenizer.from_pretrained(model_id)
llm = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
print("LLM loaded...")
def chat(files, question):
    # Split raw text into overlapping chunks sized for the embedding model.
    def get_chunks(text):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=250,
            chunk_overlap=50,
            length_function=len,
        )
        return text_splitter.split_text(text)
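
    # Example of the splitter's behavior (hypothetical input): a 600-character
    # string yields three chunks of at most 250 characters each, with a
    # 50-character overlap between neighbors.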
    # Extract text from every uploaded PDF and chunk it.
    all_chunks = []
    for file in files:
        reader = PdfReader(file)
        text = ""
        for page in reader.pages:
            text += page.extract_text() or ""
        all_chunks.extend(get_chunks(text))
    print(f"Total chunks: {len(all_chunks)}")
    print("Chunks are ready...")
    # Spin up an embedded (on-disk) Qdrant instance and rebuild the collection.
    client = QdrantClient(path="./db")
    print("DB created...")
    client.recreate_collection(
        collection_name="my_facts",
        vectors_config=models.VectorParams(
            size=encoder.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
    print("Collection created...")
    # Embed each chunk and upload it with its text stored in the payload.
    client.upload_records(
        collection_name="my_facts",
        records=[
            models.Record(
                id=idx,
                vector=encoder.encode(chunk).tolist(),
                payload={"text": chunk},
            )
            for idx, chunk in enumerate(all_chunks)
        ],
    )
    print("Records uploaded...")
    # Retrieve the three chunks closest to the question in embedding space.
    hits = client.search(
        collection_name="my_facts",
        query_vector=encoder.encode(question).tolist(),
        limit=3,
    )
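
    # Optional refinement (a sketch; the 0.3 cutoff is an assumption, not a
    # tuned value): drop weak matches before building the context.
    # hits = [hit for hit in hits if hit.score > 0.3]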
    context = " ".join(hit.payload["text"] for hit in hits)
system_prompt = """You are a helpful co-worker, you will use the provided context to answer user questions.
Read the given context before answering questions and think step by step. If you cannot answer a user question based on
the provided context, inform the user. Do not use any other information for answering user. Provide a detailed answer to the question."""
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
instruction = f"""
Context: {context}
User: {question}"""
prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
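
    # Note: this is the Llama-2 [INST]/<<SYS>> format; a Llama-3-based model
    # typically expects its own template. A sketch using the tokenizer's
    # built-in chat template instead (assuming the tokenizer ships one):
    # prompt_template = tokenizer.apply_chat_template(
    #     [{"role": "system", "content": system_prompt},
    #      {"role": "user", "content": instruction}],
    #     tokenize=False,
    #     add_generation_prompt=True,
    # )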
    print(prompt_template)
    # transformers models are not callable on a raw string; tokenize + generate.
    inputs = tokenizer(prompt_template, return_tensors="pt").to(llm.device)
    outputs = llm.generate(**inputs, max_new_tokens=300)
    result = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return result
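
# Quick local smoke test (hypothetical file name; replace with a real PDF):
# print(chat(["handbook.pdf"], "What is the leave policy?"))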
screen = gr.Interface(
    fn=chat,
    inputs=[
        gr.File(label="Upload PDFs", file_count="multiple"),
        gr.Textbox(lines=10, placeholder="Enter your question here 👉"),
    ],
    outputs=gr.Textbox(lines=10, placeholder="Your answer will be here soon 🚀"),
    title="Q&A with PDFs 👩🏻‍💻📓✍🏻💡",
    description="This app lets you have a conversation with the uploaded PDFs 💡",
    theme="soft",
)
screen.launch()