nyp_chatbot / model.py
brandonongsc's picture
Upload model.py
4bc5561 verified
raw
history blame
5.99 kB
import streamlit as st
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import geocoder
from geopy.distance import geodesic #help calc nearest clinic
import pandas as pd
import folium #using folium maps allows us to show richer details on the map like tooltips
from streamlit_folium import folium_static
DB_FAISS_PATH = 'vectorstores/db_faiss'
custom_prompt_template = """Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context: {context}
Question: {question}
Only return the helpful answer below and nothing else.
Helpful answer:
"""
def set_custom_prompt():
prompt = PromptTemplate(template=custom_prompt_template,
input_variables=['context', 'question'])
return prompt
def retrieval_qa_chain(llm, prompt, db):
qa_chain = RetrievalQA.from_chain_type(llm=llm,
chain_type='stuff',
retriever=db.as_retriever(search_kwargs={'k': 2}),
return_source_documents=True,
chain_type_kwargs={'prompt': prompt}
)
return qa_chain
def load_llm():
llm = CTransformers(
model="TheBloke/Llama-2-7B-Chat-GGML",
model_type="llama",
max_new_tokens=512,
temperature=0.5
)
return llm
import folium
def qa_bot(query):
if 'nearest TCM clinic' in query:
# Get user's current location
g = geocoder.ip('me')
user_lat, user_lon = g.latlng
# Load locations from the CSV file
locations_df = pd.read_csv("dataset/locations.csv")
# Filter locations within 5km from user's current location
filtered_locations_df = locations_df[locations_df.apply(lambda row: geodesic((user_lat, user_lon), (row['latitude'], row['longitude'])).kilometers <= 5, axis=1)]
# Create map centered at user's location
my_map = folium.Map(location=[user_lat, user_lon], zoom_start=12)
# Add markers with custom tooltips for filtered locations
for index, location in filtered_locations_df.iterrows():
folium.Marker(location=[location['latitude'], location['longitude']], tooltip=f"{location['name']}<br>Reviews: {location['Stars_review']}<br>Avg Price $: {location['Price']}<br>Contact No: {location['Contact']}").add_to(my_map)
# Display map
folium_static(my_map)
return "Displaying locations within 5km from your current location."
else:
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
model_kwargs={'device': 'cpu'})
db = FAISS.load_local(DB_FAISS_PATH, embeddings)
llm = load_llm()
qa_prompt = set_custom_prompt()
qa = retrieval_qa_chain(llm, qa_prompt, db)
# Implement the question-answering logic here
response = qa({'query': query})
return response['result']
def add_vertical_space(spaces=1):
for _ in range(spaces):
st.markdown("---")
def main():
st.set_page_config(page_title="Ask me anything about TCM")
with st.sidebar:
st.title('Welcome to Nexus AI TCM!')
st.markdown('''
<style>
[data-testid=stSidebar] {
background-color: #ffffff;
}
</style>
<img src="http://40.90.239.142/bongvm/img/nexus_logo4.png" width=200>
''', unsafe_allow_html=True)
add_vertical_space(1) # Adjust the number of spaces as needed
#st.write('Made by [@ThisIs-Developer](https://huggingface.co/ThisIs-Developer)')
st.title("Nexus AI TCM")
st.markdown(
"""
<style>
.chat-container {
display: flex;
flex-direction: column;
height: 400px;
overflow-y: auto;
padding: 10px;
color: white; /* Font color */
}
.user-bubble {
background-color: #007bff; /* Blue color for user */
align-self: flex-end;
border-radius: 10px;
padding: 8px;
margin: 5px;
max-width: 70%;
word-wrap: break-word;
}
.bot-bubble {
background-color: #363636; /* Slightly lighter background color */
align-self: flex-start;
border-radius: 10px;
padding: 8px;
margin: 5px;
max-width: 70%;
word-wrap: break-word;
}
</style>
"""
, unsafe_allow_html=True)
conversation = st.session_state.get("conversation", [])
query = st.text_input("Ask your question here:", key="user_input")
if st.button("Get Answer"):
if query:
with st.spinner("Processing your question..."): # Display the processing message
conversation.append({"role": "user", "message": query})
# Call your QA function
answer = qa_bot(query)
conversation.append({"role": "bot", "message": answer})
st.session_state.conversation = conversation
else:
st.warning("Please input a question.")
chat_container = st.empty()
chat_bubbles = ''.join([f'<div class="{c["role"]}-bubble">{c["message"]}</div>' for c in conversation])
chat_container.markdown(f'<div class="chat-container">{chat_bubbles}</div>', unsafe_allow_html=True)
if __name__ == "__main__":
main()