import logging
import os
import random
import re

import gradio as gr
import groq  # Assuming you are using the Groq library
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the counseling-conversation dataset from the Hugging Face Hub.
ds = load_dataset("Amod/mental_health_counseling_conversations")

# Two parallel columns of the "train" split; entry i of `response` belongs to
# the same example as entry i of `context`.
context, response = ds["train"]["Context"], ds["train"]["Response"]

# T5 checkpoint (small variant) that produces the "Generated Response" text
# in the offline fallback path.
model_name = "t5-small"
tokenizer, model = (
    T5Tokenizer.from_pretrained(model_name),
    T5ForConditionalGeneration.from_pretrained(model_name),
)

# Groq API client. The key is read from the GROQ_API_KEY environment variable
# so that no secret is stored in source control. Any key that was previously
# committed here must be treated as leaked and revoked.
# When no key is configured, `client` is None and chatbot() falls back to the
# offline dataset/T5 path instead of the app failing at start-up.
api_key = os.environ.get("GROQ_API_KEY")
client = groq.Client(api_key=api_key) if api_key else None

# Function to simulate conversation
def chatbot(user_input):
    if not user_input.strip():
        return "Please enter a question or concern to receive guidance."

    # Calculate the word count and remaining characters for the input
    word_count = len(user_input.split())
    max_words = 50  # Max words allowed for input
    remaining_words = max_words - word_count

    if remaining_words < 0:
        return f"Your input is too long. Please limit to {max_words} words. Words remaining: 0."

    # Try using the Groq API for the personalized response
    try:
        brief_response = client.predict(user_input)  # Make sure this method exists for your Groq client
    except Exception as e:
        brief_response = None  # If Groq fails, fall back to dataset

    if brief_response:
        return f"**Personalized Response:** {brief_response}"

    # If Groq API does not work, fallback to dataset
    idx = random.randint(0, len(context) - 1)
    context_text = context[idx]
    response_text = response[idx]

    # Generate response using T5 (RAG approach)
    inputs = tokenizer.encode("summarize: " + user_input, return_tensors="pt", max_length=512, truncation=True)
    summary_ids = model.generate(inputs, max_length=100, num_beams=4, early_stopping=True)
    generated_response = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    if not generated_response:
        return "Oops, sorry, I don't have information about your specific problem. Please visit a doctor to prevent mishaps."

    # Final response combining generated answer and dataset info
    complete_response = (
        f"**Contextual Information:**\n{context_text}\n\n"
        f"**Generated Response:**\n{generated_response}\n\n"
        f"**Fallback Response:**\n{response_text}"
    )

    return f"{complete_response}\n\nWords entered: {word_count}, Words remaining: {remaining_words}"

# Gradio UI: a single multi-line question box whose text is passed to
# chatbot(), with the returned string rendered as Markdown.
_question_box = gr.Textbox(
    label="Ask your question:",
    placeholder="Describe how you're feeling today...",
    lines=4,
)
_answer_view = gr.Markdown(label="Psychologist Assistant Response")

interface = gr.Interface(
    chatbot,
    _question_box,
    _answer_view,
    title="Virtual Psychiatrist Assistant",
    description="Enter your mental health concerns, and receive guidance and responses from a trained assistant.",
    # NOTE(review): "huggingface" was a built-in theme name in older Gradio
    # releases; newer versions may warn and use the default theme instead.
    # Confirm against the installed Gradio version.
    theme="huggingface",
)

# Start the web server (runs at import time, as in the original script).
interface.launch()