import streamlit as st
from langchain_core.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms.ctransformers import CTransformers
from langchain.chains.retrieval_qa.base import RetrievalQA

# Directory that holds the prebuilt FAISS index
DB_FAISS_PATH = 'vectorstores/'
custom_prompt_template = '''Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
'''
def set_custom_prompt():
    """Build the PromptTemplate used by the RetrievalQA chain."""
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=['context', 'question'])
    return prompt
def load_llm():
    # CTransformers runs GGML/GGUF-quantized models on CPU. Note that the
    # 'epfl-llm/meditron-7b' repo hosts full-precision weights, so you may
    # need to point `model` at a quantized .gguf file instead. Meditron is a
    # Llama-2 derivative, hence model_type='llama'.
    llm = CTransformers(
        model='epfl-llm/meditron-7b',
        model_type='llama',
        config={'max_new_tokens': 512, 'temperature': 0.5}
    )
    return llm
def load_embeddings():
    embeddings = HuggingFaceBgeEmbeddings(model_name='NeuML/pubmedbert-base-embeddings',
                                          model_kwargs={'device': 'cpu'})
    return embeddings
def load_faiss_index(embeddings):
    # allow_dangerous_deserialization is required because FAISS indexes are
    # pickled on disk; only enable it for index files you created yourself.
    db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
    return db
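# This app assumes the FAISS index at DB_FAISS_PATH already exists. A minimal,
# hypothetical ingest sketch (not part of this app) could look like the
# following, assuming your source PDFs live in a 'data/' directory:
#
#     from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
#     from langchain_text_splitters import RecursiveCharacterTextSplitter
#
#     def build_faiss_index():
#         docs = DirectoryLoader('data/', glob='*.pdf', loader_cls=PyPDFLoader).load()
#         chunks = RecursiveCharacterTextSplitter(chunk_size=500,
#                                                 chunk_overlap=50).split_documents(docs)
#         db = FAISS.from_documents(chunks, load_embeddings())
#         db.save_local(DB_FAISS_PATH)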
def retrieval_qa_chain(llm, prompt, db):
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qa_chain
@st.cache_resource
def qa_bot():
    # Cached so Streamlit reruns don't reload the model, embeddings, and
    # index on every interaction.
    embeddings = load_embeddings()
    db = load_faiss_index(embeddings)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, db)
    return qa
def final_result(query):
    qa = qa_bot()
    response = qa.invoke({'query': query})
    return response
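# The response is a dict with a 'result' string and, because
# return_source_documents=True, a 'source_documents' list. A quick smoke test
# outside Streamlit (hypothetical query) would look like:
#
#     answer = final_result('What are the symptoms of anemia?')
#     print(answer['result'])
#     for doc in answer['source_documents']:
#         print(doc.metadata.get('source'), doc.metadata.get('page'))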
# Warm the cached QA chain once at startup
qa_bot()
# Streamlit webpage title
st.title('Medical Chatbot')

# User input
user_query = st.text_input("Please enter your question:")

# Button to get answer
if st.button('Get Answer'):
    if user_query:
        # Run the query through the RetrievalQA chain
        response = final_result(user_query)
        if response:
            # Display the answer
            st.write("### Answer")
            st.write(response['result'])

            # Display source document details if available
            if 'source_documents' in response:
                st.write("### Source Document Information")
                for i, doc in enumerate(response['source_documents']):
                    # Replace escaped '\n' sequences in the stored page
                    # content with real line breaks
                    formatted_content = doc.page_content.replace("\\n", "\n")
                    st.write("#### Document Content")
                    # A unique key avoids duplicate-widget errors when
                    # several sources are shown
                    st.text_area(label="Page Content", value=formatted_content,
                                 height=300, key=f"doc_{i}")

                    # Retrieve source and page from metadata
                    source = doc.metadata.get('source', 'Unknown')
                    page = doc.metadata.get('page', 'Unknown')
                    st.write(f"Source: {source}")
                    st.write(f"Page Number: {page}")
        else:
            st.write("Sorry, I couldn't find an answer to your question.")
    else:
        st.write("Please enter a question to get an answer.")