import time
from functools import lru_cache

import streamlit as st
from streamlit_js_eval import streamlit_js_eval

from chat_agent import convo, main, choose_model1, delete_all_variables
from llama_guard import moderate_chat, get_category_name
from recommendation_agent import recommend2, choose_model2, is_depressed, start_recommend

st.set_page_config(layout="wide")

st.title('BrighterDays Mentor')

col1, col2 = st.columns([2, 3])
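
# The sidebar model choice drives both the chat agent and the
# recommendation agent.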
model = st.sidebar.selectbox(
    label="Choose the LLM model",
    options=["mistral-7b-base-model", "mental-health-mistral-7b-finetuned-model"],
)
print("\n\nSelected LLM model from dropdown:", model)
choose_model1(model)
choose_model2(model)
main()
start_recommend()
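
# Rebuilds the sidebar's "condition + advice" panel from the latest
# conversation summary.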
def update_recommendations(summary):
    with st.sidebar:
        st.divider()
        st.write("Potential Mental Health Condition:")
        st.write(is_depressed(summary))
        st.header("Mental Health Advice:")
        with st.spinner('Thinking...'):
            recommendation = recommend2(summary)
            st.write(recommendation)


# Memoized placeholder; lru_cache caches the result of this zero-argument
# function (cached_property only works on class attributes).
@lru_cache
def get_recommendations():
    return "These are some updated recommendations."
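
# Used with st.write_stream below to render replies incrementally.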
def response_generator(response):
    '''Yield the response word by word, with a short delay, to create a
    typewriter effect.'''
    response_buffer = response.strip()
    for word in response_buffer.split():
        yield word + " "
        time.sleep(0.03)
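
# Opening assistant greeting; startup() is invoked at the bottom of the script.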
def startup():
    with st.chat_message("assistant"):
        time.sleep(0.2)
        st.markdown("Hi, I am your Mental Health Counselor. How can I help you today?")


if "messages" not in st.session_state:
    st.session_state.messages = []
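
# Streamlit reruns this script on every interaction, so replay the stored
# history to keep the transcript on screen.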
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
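
# LlamaGuard moderation toggle, persisted in session state across reruns.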
if 'llama_guard_enabled' not in st.session_state:
    st.session_state['llama_guard_enabled'] = True

llama_guard_enabled = st.sidebar.checkbox(
    "Enable LlamaGuard",
    value=st.session_state['llama_guard_enabled'],
    key="llama_guard_toggle",
)
st.session_state['llama_guard_enabled'] = llama_guard_enabled

if user_prompt := st.chat_input(""):
    st.session_state.messages.append({"role": "user", "content": user_prompt})
    with st.chat_message("user"):
        st.markdown(user_prompt)

    with st.chat_message("assistant"):
        print('Llama Guard enabled:', st.session_state['llama_guard_enabled'])
        is_safe = True
        unsafe_category_name = ""
        response = ""
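        # Screen the prompt with Llama Guard before it reaches the chat agent.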
        if st.session_state['llama_guard_enabled']:
            guard_status, error = moderate_chat(user_prompt)
            if error:
                st.error(f"Failed to retrieve data from Llama Guard: {error}")
            elif 'unsafe' in guard_status[0]['generated_text']:
                is_safe = False
                unsafe_category_name = get_category_name(guard_status[0]['generated_text'])
                print(f'Guard status {guard_status}, Category name {unsafe_category_name}')
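
        # For unsafe prompts, stream a fixed redirection message; convo("")
        # is still called so a summary is available for the sidebar update.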
        if not is_safe:
            response = (
                f"I see you are asking something about {unsafe_category_name}. "
                "Due to ethical and safety reasons, I can't provide the help you "
                "need. Please reach out to someone who can, like a family member, "
                "friend, or therapist. In urgent situations, contact emergency "
                "services or a crisis hotline. Remember, asking for help is brave, "
                "and you're not alone."
            )
            st.write_stream(response_generator(response))
            response, summary = convo("")
            st.write_stream(response_generator(response))
        else:
            response, summary = convo(user_prompt)
            time.sleep(0.2)
            st.write_stream(response_generator(response))

        print("This is the response from app.py:", response)
        update_recommendations(summary)

    st.session_state.messages.append({"role": "assistant", "content": response})
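
# Clears the agents' state and reloads the browser page for a fresh session.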
if st.button("Reset Chat"):
    delete_all_variables()
    streamlit_js_eval(js_expressions="parent.window.location.reload()")

startup()