|
|
|
from typing import Any |
|
import gradio as gr |
|
from langchain_openai import OpenAIEmbeddings |
|
from langchain_community.vectorstores import Chroma |
|
|
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain_openai import ChatOpenAI |
|
|
|
from langchain_community.document_loaders import PyMuPDFLoader |
|
|
|
import fitz |
|
from PIL import Image |
|
import os |
|
import re |
|
import openai |
|
|
|
# The API key is read from the OPENAI_API_KEY environment variable instead of
# being hard-coded: a key committed to source control must be treated as leaked
# (and revoked).
openai.api_key = os.environ.get("OPENAI_API_KEY")
|
|
|
|
|
def add_text(history, text: str):
    """Append the user's message to the chat history with an empty reply slot.

    Args:
        history: Current chatbot history, a list of ``[user, bot]`` pairs
            (``None`` is treated as an empty history).
        text: The message typed by the user.

    Returns:
        A new history list ending with ``[text, ""]``; ``get_response`` fills
        the reply in afterwards. The input list is not modified.

    Raises:
        gr.Error: if *text* is empty or only whitespace.
    """
    if not text or not text.strip():
        raise gr.Error("Enter text")
    # A list (not a tuple) so the reply slot can be filled in place later.
    return (history or []) + [[text, ""]]
|
|
|
|
|
class MyApp:
    """Process-wide state of the PDF chat app.

    The Gradio callbacks in this file share one module-level instance
    (``app``), so the retrieval chain, the conversation history and the page
    to display are stored here.
    """

    def __init__(self) -> None:
        # Key passed to both the embeddings model and the chat model.
        self.OPENAI_API_KEY: str = openai.api_key
        # Retrieval chain for the current PDF; built lazily by __call__.
        self.chain = None
        # (question, answer) pairs passed back to the chain as context.
        self.chat_history: list = []
        # Page index (as used by ``doc[N]``) of the top source document of the
        # last answer.
        self.N: int = 0
        # Number of calls since the chain was last (re)built; 0 forces a rebuild.
        self.count: int = 0

    def __call__(self, file: str) -> Any:
        """Return the retrieval chain for *file*, building it when ``count`` is 0.

        Later calls reuse the cached chain until ``count`` is reset to 0.
        """
        if self.count == 0:
            self.chain = self.build_chain(file)
        self.count += 1
        return self.chain

    def process_file(self, file: str):
        """Load *file* as LangChain documents and return them with its base name.

        *file* may be a plain path string or an upload object exposing its
        path as ``.name`` (presumably depends on the Gradio version in use).

        Returns:
            ``(documents, file_name)`` where *file_name* is the path's basename.
        """
        path = getattr(file, "name", file)
        documents = PyMuPDFLoader(path).load()
        # os.path.basename understands the platform's separators, unlike the
        # previous "/"-only regex whose fallback crashed on non-str input.
        file_name = os.path.basename(path)
        return documents, file_name

    def build_chain(self, file: str):
        """Index *file* in Chroma and wrap it in a conversational retrieval chain."""
        documents, file_name = self.process_file(file)

        embeddings = OpenAIEmbeddings(openai_api_key=self.OPENAI_API_KEY)
        pdfsearch = Chroma.from_documents(
            documents,
            embeddings,
            collection_name=self._collection_name(file_name),
        )
        return ConversationalRetrievalChain.from_llm(
            ChatOpenAI(temperature=0.0, openai_api_key=self.OPENAI_API_KEY),
            # Only the single best chunk is retrieved; its page is shown in the UI.
            retriever=pdfsearch.as_retriever(search_kwargs={"k": 1}),
            return_source_documents=True,
        )

    @staticmethod
    def _collection_name(file_name: str) -> str:
        """Turn *file_name* into a name Chroma accepts for a collection.

        Chroma requires 3-63 characters from ``[A-Za-z0-9._-]``, an
        alphanumeric first and last character, no ``..`` and not an IPv4
        address; raw upload names (spaces, non-ASCII, long names) often
        violate that. Names that already satisfy the rules are returned as-is.
        """
        name = re.sub(r"[^A-Za-z0-9._-]", "_", file_name)
        name = re.sub(r"\.{2,}", ".", name)
        # After stripping, both ends are alphanumeric (or the name is empty).
        name = name[:63].strip("._-")
        if len(name) < 3:
            name = f"pdf{name}"
        if re.fullmatch(r"\d{1,3}(?:\.\d{1,3}){3}", name):
            name = f"pdf_{name}"
        return name
|
|
|
|
|
def get_response(history, query, file):
    """Answer *query* about the uploaded PDF, streaming the reply into *history*.

    Yields ``(history, "")`` after each character of the answer so the chatbot
    shows it progressively and the textbox is cleared. Also appends the
    exchange to ``app.chat_history`` and stores the page of the top source
    document in ``app.N`` (read by ``render_file``).

    Raises:
        gr.Error: if no PDF has been uploaded.
    """
    if not file:
        raise gr.Error(message="Upload a PDF")

    chain = app(file)
    result = chain(
        {"question": query, "chat_history": app.chat_history}, return_only_outputs=True
    )
    answer = result["answer"]
    app.chat_history += [(query, answer)]

    # Read the page through the Document's metadata rather than the positional
    # order of its fields; keep the previous page if nothing was retrieved.
    sources = result.get("source_documents") or []
    if sources:
        app.N = sources[0].metadata.get("page", app.N)

    # Rebuild the last entry instead of mutating it so (immutable) tuple pairs
    # work too.
    question, partial = history[-1][0], history[-1][1] or ""
    for char in answer:
        partial += char
        history[-1] = [question, partial]
        yield history, ""
    if not answer:
        # Still emit one update so the textbox is cleared.
        yield history, ""
|
|
|
|
|
def _pdf_path(file):
    """Return the filesystem path of an uploaded PDF.

    Accepts either a path string or an object exposing the path as ``.name``.
    """
    return getattr(file, "name", file)


def _render_page(file, page_number: int):
    """Render one page of the PDF *file* at 150 dpi as an RGB PIL image.

    *page_number* is clamped to the document's page range. The document is
    closed before returning.
    """
    with fitz.open(_pdf_path(file)) as doc:
        index = max(0, min(page_number, len(doc) - 1))
        pix = doc[index].get_pixmap(dpi=150)
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)


def render_file(file):
    """Render the page of *file* that ``app.N`` points at."""
    return _render_page(file, app.N)


def purge_chat_and_render_first(file):
    """Reset the conversation for a newly uploaded PDF and render its first page.

    Clears ``app.chat_history``, sets ``app.count`` to 0 so the next query
    rebuilds the retrieval chain, and resets ``app.N`` to the first page.

    Returns:
        ``(image of the first page, [])`` — the empty list clears the chatbot.
    """
    print("purge_chat_and_render_first")

    app.chat_history = []
    app.count = 0
    app.N = 0

    return _render_page(file, 0), []
|
|
|
|
|
# Single shared application state used by the callbacks above.
app = MyApp()

with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            # Left: conversation and question input.
            with gr.Column(scale=2):
                with gr.Row():
                    chatbot = gr.Chatbot(value=[], elem_id="chatbot")
                with gr.Row():
                    txt = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press submit",
                        scale=2,
                    )
                    submit_btn = gr.Button("Submit", scale=1)

            # Right: page preview and PDF upload.
            with gr.Column(scale=1):
                with gr.Row():
                    show_img = gr.Image(label="Upload PDF")
                with gr.Row():
                    btn = gr.UploadButton("📁 Upload a PDF", file_types=[".pdf"])

    # A new upload resets the conversation and previews the first page.
    btn.upload(
        fn=purge_chat_and_render_first,
        inputs=[btn],
        outputs=[show_img, chatbot],
    )

    # Submit: append the question, stream the answer, then show the source
    # page; each step runs only if the previous one succeeded.
    submit_btn.click(
        fn=add_text,
        inputs=[chatbot, txt],
        outputs=[chatbot],
        queue=False,
    ).success(
        fn=get_response, inputs=[chatbot, txt, btn], outputs=[chatbot, txt]
    ).success(
        fn=render_file, inputs=[btn], outputs=[show_img]
    )

if __name__ == "__main__":
    # Enable the request queue, which streaming (generator) callbacks such as
    # get_response rely on, then start the server.
    demo.queue()
    demo.launch()
|
|
|
|
|
|