mavinsao's picture
Update app.py
43f3fd4 verified
raw
history blame
2.68 kB
from transformers import pipeline
import streamlit as st
import streamlit.components.v1 as components
# Load the models.
# Streamlit re-executes this script on every widget interaction, so the
# pipelines are built through a cached factory and reused across reruns
# instead of being re-instantiated on each button click.
@st.cache_resource
def _load_classifier(model_name):
    """Return the text-classification pipeline for ``model_name``.

    The result is cached by Streamlit, so each distinct ``model_name`` is
    only instantiated once per server process.
    """
    return pipeline("text-classification", model=model_name)


pipe_1 = _load_classifier("mavinsao/roberta-base-finetuned-mental-health")
pipe_2 = _load_classifier("mavinsao/mi-roberta-base-finetuned-mental-health")
# Function for ensemble prediction
def ensemble_predict(text):
    """Classify ``text`` with both pipelines and average their scores.

    Every ``{'label': ..., 'score': ...}`` entry returned by a pipeline adds
    ``score / number_of_models`` to the running total of its label. The label
    with the largest total is the ensemble prediction.

    Args:
        text: Input string handed to ``pipe_1`` and ``pipe_2``.

    Returns:
        A ``(predicted_label, confidence)`` tuple, where ``confidence`` is the
        averaged score accumulated for ``predicted_label``.
    """
    # truncation=True is forwarded to the tokenizer so that input longer than
    # the model's maximum sequence length is clipped instead of making the
    # forward pass fail; shorter input is unaffected.
    model_outputs = [pipe(text, truncation=True) for pipe in (pipe_1, pipe_2)]
    model_count = len(model_outputs)

    # NOTE(review): only labels that a pipeline actually returns are
    # accumulated. A label missing from one model's output gets no share from
    # that model -- confirm this is the intended ensembling rule.
    ensemble_scores = {}
    for results in model_outputs:
        for result in results:
            label = result['label']
            ensemble_scores[label] = (
                ensemble_scores.get(label, 0) + result['score'] / model_count
            )

    # On ties, max() keeps the label that was inserted first.
    predicted_label = max(ensemble_scores, key=ensemble_scores.get)
    return predicted_label, ensemble_scores[predicted_label]
# Streamlit app
st.title('Mental Illness Prediction')

# Input text area for user input
sentence = st.text_area("Enter the long sentence to predict your mental illness state:")

if st.button('Predict'):
    if not sentence or not sentence.strip():
        # Nothing meaningful to classify: ask for input instead of showing a
        # label and confidence computed from an empty string.
        st.warning("Please enter some text before requesting a prediction.")
    else:
        # Perform the prediction
        predicted_label, confidence = ensemble_predict(sentence)

        # CSS injection to target the labels.
        # NOTE(review): the selector addresses "metric-container" elements,
        # while the results below are rendered with st.write -- confirm this
        # rule still has a visible effect.
        st.markdown("""
<style>
div[data-testid="metric-container"] {
font-weight: bold;
font-size: 18px; /* Adjust the font size as desired */
}
</style>
""", unsafe_allow_html=True)

        # Display the result
        st.write("Result:", predicted_label)
        st.write("Confidence:", confidence)

        # Additional reminder after prediction
        st.info("Remember: This prediction is not a diagnosis. Always consult with a healthcare professional for proper evaluation and advice.")
# Additional information: static explanatory text describing how the
# predictions are meant to be used.
_about_method_markdown = """
### About Our Method
Our method is designed to assist mental health professionals, such as psychologists and psychiatrists, rather than replace them. Using our model to directly calculate mental illness labels can introduce biases, potentially leading to inaccurate diagnoses. Therefore, the predictions made by our model should only be used as a reference, with the final diagnosis being carefully determined by qualified professionals.
"""
st.markdown(_about_method_markdown)