import os
import uuid
from io import BytesIO

import requests
import torch
import fitz  # PyMuPDF
import transformers
from PIL import Image
from pdf2image import convert_from_path
from gtts import gTTS
from pydub import AudioSegment

import streamlit as st
import streamlit.components.v1 as components
from streamlit.logger import get_logger
from streamlit_extras.stateful_button import button

from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer, TextStreamer, pipeline

from langchain import HuggingFacePipeline, PromptTemplate
from langchain.chains import ConversationalRetrievalChain, RetrievalQA
from langchain.document_loaders import PyPDFDirectoryLoader, UnstructuredMarkdownLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.vectorstores.utils import filter_complex_metadata
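
# Per-session id, logging, page chrome, and the torch device (GPU when available).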
user_session_id = uuid.uuid4()

logger = get_logger(__name__)

st.set_page_config(page_title="Document QA by Dono", page_icon="🤖")
st.session_state.disabled = False
st.title("Document QA by Dono")

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
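

# Parse every PDF under /home/user/app/pdfs/ once; st.cache_data keeps the result across reruns.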
@st.cache_data
def load_data():
    loader = PyPDFDirectoryLoader("/home/user/app/pdfs/")
    docs = loader.load()
    return docs
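

# Build the QA stack once per process: chunk and embed the documents into a FAISS index,
# load the GPTQ-quantised Llama-2 13B chat model, and wire both into a RetrievalQA chain.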
@st.cache_resource
def load_model(_docs):
    # Leading underscore keeps st.cache_resource from trying to hash the documents argument.
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="/home/user/app/all-MiniLM-L6-v2/",
        model_kwargs={"device": DEVICE},
    )
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    texts = text_splitter.split_documents(_docs)
    db = FAISS.from_documents(texts, embeddings)

    model_name_or_path = "/home/user/app/Llama-2-13B-chat-GPTQ/"
    model_basename = "model"
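
    # Tokenizer and 8-bit GPTQ-quantised Llama-2 13B chat weights, loaded from local disk.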
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

    model = AutoGPTQForCausalLM.from_quantized(
        model_name_or_path,
        revision="gptq-8bit-128g-actorder_False",
        model_basename=model_basename,
        use_safetensors=True,
        trust_remote_code=True,
        inject_fused_attention=False,
        device=DEVICE,
        quantize_config=None,
    )

    DEFAULT_SYSTEM_PROMPT = """
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
Always provide the citation for the answer from the text.
Try to include any section or subsection in the text that supports the answer.
Provide references: page number, section, subsection, etc.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrect. If you don't know the answer to a question, please don't share false information.
Given a government document that outlines rules and regulations for a specific industry or sector, use your language model to answer questions about the rules and their applicability over time.
The document may include provisions that take effect at different times, such as immediately upon publication, after a grace period, or on a specific date in the future.
Your task is to identify the relevant rules and determine when they go into effect, taking into account any dependencies or exceptions that may apply.
The current date is 14 September, 2023. Try to extract information which is closest to this date.
Take a deep breath and work on this problem step-by-step.
""".strip()

    def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
        return f"""[INST] <<SYS>>{system_prompt}<</SYS>>{prompt} [/INST]""".strip()
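
    # Wrap the model in a Hugging Face text-generation pipeline; TextStreamer prints tokens
    # to the server console as they are generated.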
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    text_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        temperature=0.2,
        top_p=0.95,
        repetition_penalty=1.15,
        streamer=streamer,
    )
    llm = HuggingFacePipeline(pipeline=text_pipeline, model_kwargs={"temperature": 0.2})

    SYSTEM_PROMPT = (
        "Use the following pieces of context to answer the question at the end. "
        "If you don't know the answer, just say that you don't know, "
        "don't try to make up an answer."
    )

    template = generate_prompt(
        """{context} Question: {question} """,
        system_prompt=SYSTEM_PROMPT,
    )
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={"k": 5}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "verbose": False},
    )

    print('load done')
    return qa_chain


model_name_or_path = "Llama-2-13B-chat-GPTQ"
model_basename = "model"

st.session_state["llm_model"] = model_name_or_path

if "messages" not in st.session_state:
    st.session_state.messages = []
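

# Re-render the conversation so far on every Streamlit rerun.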
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


def on_select():
    st.session_state.disabled = True


def get_message_history():
    for message in st.session_state.messages:
        role, content = message["role"], message["content"]
        yield f"{role.title()}: {content}"
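

# Load the cached documents and chain, then handle a single chat turn.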
docs = load_data()
qa_chain = load_model(docs)

if prompt := st.chat_input("How can I help you today?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        message_history = "\n".join(list(get_message_history())[-3:])
        result = qa_chain(prompt)
        output = [result['result']]
        for item in output:
            full_response += item
            message_placeholder.markdown(full_response + "▌")
        message_placeholder.markdown(full_response)
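
    # Highlight the retrieved passage in its source PDF and render the cited page as a PNG.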
    def generate_pdf():
        page_number = int(result['source_documents'][0].metadata['page'])
        doc = fitz.open(str(result['source_documents'][0].metadata['source']))
        text = str(result['source_documents'][0].page_content)
        if text != '':
            for page in doc:
                text_instances = page.search_for(text)
                for inst in text_instances:
                    highlight = page.add_highlight_annot(inst)
                    highlight.update()
        doc.save("/home/user/app/pdf2image/output.pdf", garbage=4, deflate=True, clean=True)

        def pdf_page_to_image(pdf_file, page_number, output_image):
            pdf_document = fitz.open(pdf_file)
            page = pdf_document[page_number]
            dpi = 300
            pix = page.get_pixmap(matrix=fitz.Matrix(dpi / 100, dpi / 100))
            pix.save(output_image, "png")
            pdf_document.close()

        pdf_page_to_image('/home/user/app/pdf2image/output.pdf', page_number, '/home/user/app/pdf2image/output.png')
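
    # Speak the answer: gTTS writes an MP3, pydub re-exports it as WAV for st.audio.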
    def generate_audio():
        with open('/home/user/app/audio/audio.mp3', 'wb') as sound_file:
            tts = gTTS(result['result'], lang='en', tld='co.in')
            tts.write_to_fp(sound_file)
        sound = AudioSegment.from_mp3("/home/user/app/audio/audio.mp3")
        sound.export("/home/user/app/audio/audio.wav", format="wav")

    if "reference" not in st.session_state:
        st.session_state.reference = False
    if "audio" not in st.session_state:
        st.session_state.audio = False
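
    # Sidebar: on request, show the highlighted source page and play the spoken answer.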
    with st.sidebar:
        choice = st.radio("References and TTS", ["Reference & TTS"], index=None)
        if choice == 'Reference & TTS':
            generate_pdf()
            st.session_state['reference'] = '/home/user/app/pdf2image/output.png'
            st.image(st.session_state['reference'])
            st.session_state.reference = True

            generate_audio()
            st.session_state['audio'] = '/home/user/app/audio/audio.wav'
            st.audio(st.session_state['audio'])

    st.session_state.messages.append({"role": "assistant", "content": full_response})