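"""Gradio app for question answering over uploaded PDFs.

Pipeline: extract text with PyPDF2, chunk it, embed the chunks with a
SentenceTransformer, index them in Qdrant, retrieve the closest matches
for a question, and generate an answer with a Llama-family model.
"""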
import torch
import gradio as gr
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM
# Loading the embedding model
encoder = SentenceTransformer('jinaai/jina-embedding-b-en-v1')
print("Embedding model loaded...")
# Loading the LLM
# Alternative: a quantized GGUF model served via ctransformers
# (assumes the file lives in TheBloke/Llama-2-7B-Chat-GGUF):
'''
from ctransformers import AutoModelForCausalLM as CAutoModelForCausalLM
llm = CAutoModelForCausalLM.from_pretrained(
    "TheBloke/Llama-2-7B-Chat-GGUF",
    model_file="llama-2-7b-chat.Q3_K_S.gguf",
    model_type="llama",
    temperature=0.2,
    repetition_penalty=1.5,
    max_new_tokens=300,
)
'''
model_id = "refuelai/Llama-3-Refueled"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" requires the `accelerate` package
llm = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
print("LLM loaded...")
def chat(files, question):
    # Split extracted text into overlapping chunks for embedding
    def get_chunks(text):
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=250,
            chunk_overlap=50,
            length_function=len,
        )
        chunks = text_splitter.split_text(text)
        return chunks

    # Extract the text of every uploaded PDF and chunk it
    all_chunks = []
    for file in files:
        pdf_path = file
        reader = PdfReader(pdf_path)
        text = ""
        num_of_pages = len(reader.pages)
        for page in range(num_of_pages):
            current_page = reader.pages[page]
            text += current_page.extract_text()
        chunks = get_chunks(text)
        all_chunks.extend(chunks)
    print(f"Total chunks: {len(all_chunks)}")
    print("Chunks are ready...")
    # Create a fresh Qdrant collection sized to the encoder's output
    client = QdrantClient(path="./db")
    print("DB created...")
    client.recreate_collection(
        collection_name="my_facts",
        vectors_config=models.VectorParams(
            size=encoder.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )
    print("Collection created...")

    # Embed every chunk and upload it with its text as payload
    records = [
        models.Record(
            id=idx,
            vector=encoder.encode(chunk).tolist(),
            payload={f"chunk_{idx}": chunk},
        )
        for idx, chunk in enumerate(all_chunks)
    ]
    client.upload_records(collection_name="my_facts", records=records)
    print("Records uploaded...")
    # Retrieve the 3 chunks most similar to the question
    hits = client.search(
        collection_name="my_facts",
        query_vector=encoder.encode(question).tolist(),
        limit=3,
    )
    context = " ".join(list(hit.payload.values())[0] for hit in hits)

    # Build a Llama-2-chat-style [INST] prompt around the retrieved context
    system_prompt = """You are a helpful co-worker; you will use the provided context to answer user questions.
Read the given context before answering questions and think step by step. If you cannot answer a user's question based on
the provided context, inform the user. Do not use any other information to answer the user. Provide a detailed answer to the question."""
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
    SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
    instruction = f"""
Context: {context}
User: {question}"""
    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
    print(prompt_template)

    # Generate the answer: tokenize, run the model, and decode only the new tokens
    inputs = tokenizer(prompt_template, return_tensors="pt").to(llm.device)
    outputs = llm.generate(**inputs, max_new_tokens=300, temperature=0.2, do_sample=True)
    result = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return result
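# Example invocation (hypothetical file path):
#   chat(["./docs/report.pdf"], "What are the key findings?")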
screen = gr.Interface(
    fn=chat,
    inputs=[
        gr.File(label="Upload PDFs", file_count="multiple"),
        gr.Textbox(lines=10, placeholder="Enter your question here"),
    ],
    outputs=gr.Textbox(lines=10, placeholder="Your answer will be here soon"),
    title="Q&A with PDFs",
    description="This app facilitates a conversation with the PDFs you upload.",
    theme="soft",
)
screen.launch()