# NOTE: the lines below are Hugging Face page metadata that was captured when
# this file was copied from the Space's web view; kept as comments so the
# module remains valid Python.
# Zeeshan42's picture
# Update app.py
# 268c7ee verified
# raw
# history blame
# 3.23 kB
import os
import random

import gradio as gr
import groq  # Groq API client for the personalized-response path
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer
# --- Data & model setup -------------------------------------------------
# Counseling-conversations dataset used as the fallback knowledge base.
ds = load_dataset("Amod/mental_health_counseling_conversations")
# Extract the parallel columns (the dataset uses capitalized column names).
context = ds["train"]["Context"]  # Column name is 'Context'
response = ds["train"]["Response"]  # Column name is 'Response'
# Load T5 (small version) used to generate a summary-style reply.
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# SECURITY: never hard-code API keys in source. The key previously embedded
# here was committed to a public repo and must be revoked. Read the key from
# the environment instead (set GROQ_API_KEY in the Space/host settings).
api_key = os.environ.get("GROQ_API_KEY", "")
client = groq.Client(api_key=api_key)
# Function to simulate conversation
def chatbot(user_input):
    """Answer a mental-health question with a Markdown-formatted reply.

    Tries the Groq chat-completions API first for a personalized response.
    If that fails for any reason, falls back to a random dataset example
    combined with a T5-generated summary of the user's input.

    Args:
        user_input: Free-text question or concern from the user.

    Returns:
        A Markdown string (guidance, a validation message, or an apology).
    """
    if not user_input.strip():
        return "Please enter a question or concern to receive guidance."
    # Enforce a 50-word input limit and report usage back to the user.
    word_count = len(user_input.split())
    max_words = 50  # Max words allowed for input
    remaining_words = max_words - word_count
    if remaining_words < 0:
        return f"Your input is too long. Please limit to {max_words} words. Words remaining: 0."
    # Preferred path: ask the Groq chat API for a personalized response.
    # BUG FIX: groq.Client has no `predict` method — the original call raised
    # on every request, so the Groq path could never succeed. Use the real
    # chat.completions.create API instead, still best-effort.
    try:
        completion = client.chat.completions.create(
            model="llama3-8b-8192",
            messages=[{"role": "user", "content": user_input}],
        )
        brief_response = completion.choices[0].message.content
    except Exception:
        brief_response = None  # If Groq fails, fall back to the dataset below.
    if brief_response:
        return f"**Personalized Response:** {brief_response}"
    # Fallback: pick a random dataset example to provide context.
    idx = random.randint(0, len(context) - 1)
    context_text = context[idx]
    response_text = response[idx]
    # Generate a reply with T5 (summarization of the user's input).
    inputs = tokenizer.encode("summarize: " + user_input, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = model.generate(inputs, max_length=100, num_beams=4, early_stopping=True)
    generated_response = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    if not generated_response:
        return "Oops, sorry, I don't have information about your specific problem. Please visit a doctor to prevent mishaps."
    # Final response combining the generated answer and dataset info.
    complete_response = (
        f"**Contextual Information:**\n{context_text}\n\n"
        f"**Generated Response:**\n{generated_response}\n\n"
        f"**Fallback Response:**\n{response_text}"
    )
    return f"{complete_response}\n\nWords entered: {word_count}, Words remaining: {remaining_words}"
# --- Gradio UI ----------------------------------------------------------
# Input widget: a multi-line text box for the user's concern.
question_box = gr.Textbox(
    label="Ask your question:",
    placeholder="Describe how you're feeling today...",
    lines=4,
)
# Output widget: responses are rendered as Markdown.
answer_panel = gr.Markdown(label="Psychologist Assistant Response")

interface = gr.Interface(
    fn=chatbot,
    inputs=question_box,
    outputs=answer_panel,
    title="Virtual Psychiatrist Assistant",
    description="Enter your mental health concerns, and receive guidance and responses from a trained assistant.",
    theme="huggingface",  # Optional: apply a theme if available
)

# Start the web server for the app.
interface.launch()