|
from langchain_core.prompts import PromptTemplate |
|
from langchain_community.embeddings import HuggingFaceBgeEmbeddings |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_community.llms.ctransformers import CTransformers |
|
from langchain.chains.retrieval_qa.base import RetrievalQA |
|
from langchain_community.llms import HuggingFaceHub |
|
|
|
from langchain.document_loaders import PyPDFLoader |
|
from langchain.document_loaders import PyPDFDirectoryLoader |
|
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores import FAISS |
|
|
|
from langchain_community.embeddings import HuggingFaceBgeEmbeddings |
|
|
|
from langchain.prompts import PromptTemplate |
|
|
|
from langchain.chains import create_retrieval_chain |
|
from langchain.chains import RetrievalQA |
|
|
|
from langchain.chains.combine_documents import create_stuff_documents_chain |
|
|
|
import os |
|
import streamlit as st |
|
import fitz |
|
from PIL import Image |
|
import io |
|
|
|
# Directory holding the prebuilt FAISS vector index (loaded in qa_bot()).
DB_FAISS_PATH = 'vectorstores/'

# PDF rendered for the per-source page previews in the Streamlit UI below.
# NOTE(review): lowercase name, but used as a module-level constant — kept as-is
# because the script body references `pdf_path` directly.
pdf_path = 'Oxford/Oxford-psychiatric-handbook-1-760.pdf'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prompt used by the retrieval QA chain: '{context}' is filled with the
# retrieved chunks and '{question}' with the user's query (see set_custom_prompt).
# Fix: the original used a chained assignment
# (`custom_prompt_template = prompt_template = """..."""`) that leaked a
# redundant, unused global `prompt_template`; only one name is needed.
custom_prompt_template = """
Use the following piece of context to answer the question asked.
Please try to provide the answer only based on the context
{context}
Question:{question}
"""
|
def set_custom_prompt():
    """Build the PromptTemplate for QA retrieval over the vector store.

    Returns:
        PromptTemplate: the module-level prompt text wired to the
        'context' and 'question' input variables expected by the chain.
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
|
|
|
|
|
def load_llm():
    """Instantiate the hosted Mistral-7B model through the HuggingFace Hub.

    Returns:
        HuggingFaceHub: remote LLM handle; low temperature keeps answers
        close to the retrieved context, max_length caps generation size.
    """
    generation_kwargs = {'temperature': 0.1, "max_length": 500}
    return HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-v0.1",
        model_kwargs=generation_kwargs,
    )
|
|
|
def retrieval_qa_chain(llm, prompt, db):
    """Assemble a 'stuff'-type RetrievalQA chain over the given vector store.

    Args:
        llm: language model used to generate the answer.
        prompt: PromptTemplate injected into the combine-documents step.
        db: vector store providing similarity search.

    Returns:
        RetrievalQA: chain that also returns its source documents.
    """
    # Top-3 similarity hits are stuffed verbatim into the prompt context.
    retriever = db.as_retriever(search_type='similarity', search_kwargs={'k': 3})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
|
|
|
def qa_bot():
    """Wire embeddings, the FAISS index, the LLM and the prompt into one QA chain.

    Returns:
        RetrievalQA: ready-to-query chain backed by the local FAISS index.
    """
    embedder = HuggingFaceBgeEmbeddings(
        model_name='BAAI/bge-small-en-v1.5',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )
    # allow_dangerous_deserialization: the index under DB_FAISS_PATH is a
    # locally-produced pickle, so loading it is trusted here.
    vector_store = FAISS.load_local(
        DB_FAISS_PATH, embedder, allow_dangerous_deserialization=True
    )
    return retrieval_qa_chain(load_llm(), set_custom_prompt(), vector_store)
|
|
|
def final_result(query):
    """Run *query* through a freshly built QA chain.

    Args:
        query: the user's question as plain text.

    Returns:
        dict: chain output containing 'result' and 'source_documents'.
    """
    # NOTE: the whole chain (embeddings, index, LLM) is rebuilt per call.
    chain = qa_bot()
    return chain({'query': query})
|
|
|
def get_pdf_page_as_image(pdf_path, page_number):
    """Render one page of a PDF to a PIL image.

    Args:
        pdf_path: path of the PDF file to open.
        page_number: zero-based page index to render.

    Returns:
        PIL.Image.Image: the rasterized page.

    Fix: the original never closed the fitz document, leaking a file handle
    on every rendered page (and on any exception while rendering).
    """
    document = fitz.open(pdf_path)
    try:
        page = document.load_page(page_number)
        pix = page.get_pixmap()
        # tobytes() yields an encoded (PNG by default) image buffer, which is
        # fully independent of the document, so closing it afterwards is safe.
        return Image.open(io.BytesIO(pix.tobytes()))
    finally:
        document.close()
|
|
|
|
|
# --- Streamlit UI: question box -> RetrievalQA answer + source previews ---
st.title('Medical Chatbot')

user_query = st.text_input("Please enter your question:")

if st.button('Get Answer'):

    if user_query:

        # Builds the full chain and queries it on every button press.
        response = final_result(user_query)

        if response:

            st.write("### Answer")
            st.write(response['result'])

            # Show each retrieved chunk alongside its provenance.
            if 'source_documents' in response:
                st.write("### Source Document Information")
                for doc in response['source_documents']:

                    # Turn literal backslash-n sequences into real newlines
                    # for readable display in the text area.
                    formatted_content = doc.page_content.replace("\\n", "\n")
                    st.write("#### Document Content")
                    st.text_area(label="Page Content", value=formatted_content, height=300)

                    source = doc.metadata['source']
                    # 'page' is zero-based here; +1 below for human display.
                    page = doc.metadata['page']
                    st.write(f"Source: {source}")
                    st.write(f"Page Number: {page+1}")

                    # NOTE(review): the preview is always rendered from the
                    # module-level `pdf_path`, not from `source` — this only
                    # matches when every indexed document came from that one
                    # PDF; verify if the index ever spans multiple files.
                    pdf_page_image = get_pdf_page_as_image(pdf_path, page)
                    st.image(pdf_page_image, caption=f"Page {page+1} from {source}")

        else:
            st.write("Sorry, I couldn't find an answer to your question.")
    else:
        st.write("Please enter a question to get an answer.")
|
|