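"""Streamlit app: Document QA by Dono.

Indexes a directory of PDFs with FAISS, translates Hindi questions to English
with mBART-50, answers them with a RetrievalQA chain over a quantized
Llama-2-13B-chat (GPTQ) model, and highlights the retrieved passage in the
source PDF shown in the sidebar.
"""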
import os
import torch
import uuid
import requests
import streamlit as st
from streamlit.logger import get_logger
from auto_gptq import AutoGPTQForCausalLM
from langchain import HuggingFacePipeline, PromptTemplate
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from pdf2image import convert_from_path
from transformers import AutoTokenizer, TextStreamer, pipeline
from langchain.memory import ConversationBufferMemory
from gtts import gTTS
from io import BytesIO
from langchain.chains import ConversationalRetrievalChain
import streamlit.components.v1 as components
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.vectorstores.utils import filter_complex_metadata
import fitz
from PIL import Image
from langchain.vectorstores import FAISS
import transformers
from pydub import AudioSegment
from streamlit_extras.streaming_write import write
import time

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

translation_model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
translation_tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")


def english_to_hindi(sentence):
    """Translate an English sentence to Hindi with mBART-50."""
    translation_tokenizer.src_lang = "en_XX"
    encoded = translation_tokenizer(sentence, return_tensors="pt")
    generated_tokens = translation_model.generate(
        **encoded, forced_bos_token_id=translation_tokenizer.lang_code_to_id["hi_IN"]
    )
    return translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)


def hindi_to_english(sentence):
    """Translate a Hindi sentence to English with mBART-50."""
    translation_tokenizer.src_lang = "hi_IN"
    encoded = translation_tokenizer(sentence, return_tensors="pt")
    generated_tokens = translation_model.generate(
        **encoded, forced_bos_token_id=translation_tokenizer.lang_code_to_id["en_XX"]
    )
    return translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)


user_session_id = uuid.uuid4()

logger = get_logger(__name__)

st.set_page_config(page_title="Document QA by Dono", page_icon="🤖")
st.session_state.disabled = False
st.title("Document QA by Dono")

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"


@st.cache_data
def load_data():
    loader = PyPDFDirectoryLoader("/home/user/app/niti/")
    docs = loader.load()
    return docs


@st.cache_resource
def load_model(_docs):
    # Split the documents, embed them and build the FAISS index
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="/home/user/app/all-MiniLM-L6-v2/", model_kwargs={"device": DEVICE}
    )
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    texts = text_splitter.split_documents(_docs)
    db = FAISS.from_documents(texts, embeddings)

    # Load the quantized Llama-2 chat model (model_name_or_path is set at module level)
    model_basename = "model"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        revision="gptq-8bit-128g-actorder_True",
        model_basename=model_basename,
        use_safetensors=True,
        trust_remote_code=True,
        inject_fused_attention=False,
        device=DEVICE,
        quantize_config=None,
    )

    DEFAULT_SYSTEM_PROMPT = """
You are a helpful, respectful and honest assistant with knowledge of machine learning, data science, computer science, Python programming language, mathematics, probability and statistics.
""".strip()

    def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
        # Wrap the prompt in the Llama-2 chat instruction format
        return f"""[INST] <<SYS>>{system_prompt}<</SYS>>{prompt} [/INST]""".strip()

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    text_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0.1,
        top_p=0.95,
        repetition_penalty=1.15,
        streamer=streamer,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0.1})

    SYSTEM_PROMPT = (
        "Use the following pieces of context along with general information you possess to answer the question at the end. "
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer."
    )

    template = generate_prompt("""{context} Question: {question} """, system_prompt=SYSTEM_PROMPT)
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 3}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "verbose": True},
    )

    print('load done')
    return qa_chain


model_name_or_path = "Llama-2-13B-chat-GPTQ"
model_basename = "model"

st.session_state["llm_model"] = model_name_or_path

if "messages" not in st.session_state:
    st.session_state.messages = []


for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


def on_select():
    st.session_state.disabled = True


def get_message_history():
    for message in st.session_state.messages:
        role, content = message["role"], message["content"]
        yield f"{role.title()}: {content}"


docs = load_data()
qa_chain = load_model(docs)


if prompt := st.chat_input("How can I help you today?"):
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("user"):
        # Queries are expected in Hindi; translate to English before retrieval
        english_prompt = hindi_to_english(prompt)[0]
        st.markdown(english_prompt)

    with st.chat_message("assistant"):
        with st.spinner(text="Looking for relevant answer"):
            message_placeholder = st.empty()
            full_response = ""
            message_history = "\n".join(list(get_message_history())[-3:])

            result = qa_chain(english_prompt)
            output = [result['result']]

            # Page of the top-ranked source document, used for the reference image
            page_number = int(result['source_documents'][0].metadata['page'])

            def generate_pdf():
                # Highlight the retrieved passage in the source PDF and save a copy
                doc = fitz.open(str(result['source_documents'][0].metadata['source']))
                text = str(result['source_documents'][0].page_content)
                if text != '':
                    for page in doc:
                        text_instances = page.search_for(text)
                        for inst in text_instances:
                            highlight = page.add_highlight_annot(inst)
                            highlight.update()
                doc.save("/home/user/app/pdf2image/output.pdf", garbage=4, deflate=True, clean=True)

            def pdf_page_to_image(pdf_file, page_number, output_image):
                # Render a single PDF page to PNG at roughly 300 dpi
                pdf_document = fitz.open(pdf_file)
                page = pdf_document[page_number]
                dpi = 300
                pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 100, dpi / 100))
                pix.save(output_image, "png")
                pdf_document.close()

            # Build the highlighted PDF first, then render the cited page as an image
            generate_pdf()
            pdf_page_to_image('/home/user/app/pdf2image/output.pdf', page_number, '/home/user/app/pdf2image/output.png')

            st.session_state['reference'] = '/home/user/app/pdf2image/default_output.png'
            st.session_state['audio'] = ''

            # Stream the answer into the chat message
            for item in output:
                full_response += item
                message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)

    if "reference" not in st.session_state:
        st.session_state.reference = False
    if "audio" not in st.session_state:
        st.session_state.audio = False

    with st.sidebar:
        choice = st.radio("References", ["Reference"])

        if choice == 'Reference':
            st.session_state['reference'] = '/home/user/app/pdf2image/output.png'
            st.image(st.session_state['reference'])

    st.session_state.messages.append({"role": "assistant", "content": full_response})