File size: 4,215 Bytes
77b9c3f
 
 
08eb120
 
 
77b9c3f
08eb120
77b9c3f
 
 
 
 
 
 
 
 
 
 
 
 
 
e00a19f
 
77b9c3f
e00a19f
 
 
 
 
 
 
 
 
 
 
 
77b9c3f
 
 
 
 
 
 
 
 
 
e00a19f
 
77b9c3f
e00a19f
 
 
 
 
 
 
 
 
 
 
 
77b9c3f
 
 
 
 
08eb120
 
77b9c3f
08eb120
ac7075c
e00a19f
 
77b9c3f
 
 
 
 
 
 
e00a19f
 
 
 
77b9c3f
 
 
 
 
ac7075c
77b9c3f
 
 
 
 
 
 
 
 
 
 
 
 
e00a19f
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
import os
import requests
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from datasets import load_dataset

# Groq API configuration.
#
# SECURITY FIX: an API key was previously hard-coded here and committed to
# source control — that key is compromised and must be revoked. Credentials
# must come from the environment (or Streamlit secrets), never from source.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")

# Groq API endpoint (base URL shared by the embedding and generation routes)
GROQ_API_URL = "https://api.groq.com/v1/inference"

# Function to perform embedding retrieval using MiniLM via Groq API
def retrieve_embedding(user_query):
    """Return the embedding vector for *user_query*, or None on any failure.

    Posts the text to the Groq embedding endpoint. All failure modes
    (network error, non-200 status, missing "embedding" key) are reported
    via st.error and yield None rather than raising, so callers can degrade
    gracefully.

    Args:
        user_query: Text to embed.

    Returns:
        list[float] embedding on success, otherwise None.
    """
    payload = {
        "model": "microsoft/MiniLM-L6-H384-uncased",
        "input_text": user_query
    }
    headers = {
        "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
    }

    # Make the API request. A timeout is essential: without one a hung
    # connection blocks the Streamlit script forever. Network-level errors
    # (DNS, refused connection, timeout) are caught so they follow the same
    # None-return contract as HTTP errors instead of crashing the app.
    try:
        response = requests.post(
            f"{GROQ_API_URL}/embedding", json=payload, headers=headers, timeout=30
        )
    except requests.RequestException as exc:
        st.error(f"Failed to reach the embedding API: {exc}")
        return None

    # Check for errors and return the embedding if available
    if response.status_code == 200:
        json_response = response.json()
        if "embedding" in json_response:
            return json_response["embedding"]
        else:
            st.error("The response from the API did not contain an embedding. Please check the API.")
            return None
    else:
        st.error(f"Failed to retrieve embedding. Status code: {response.status_code}")
        return None

# Function to perform response generation using FLAN-T5 via Groq API
def generate_response(context):
    """Generate a supportive reply conditioned on *context*, or None on failure.

    Wraps the context in a fixed instruction prompt and posts it to the Groq
    generation endpoint. Mirrors retrieve_embedding's error contract: every
    failure (network error, non-200 status, missing "text" key) is surfaced
    via st.error and returns None.

    Args:
        context: Retrieved reference text to condition the generation on.

    Returns:
        The generated text on success, otherwise None.
    """
    payload = {
        "model": "google/flan-t5-small",
        "input_text": f"Given the following context, provide a supportive response: {context}"
    }
    headers = {
        "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
    }

    # Make the API request. Timeout + exception handling for the same reasons
    # as retrieve_embedding: never hang the UI, never crash on network errors.
    try:
        response = requests.post(
            f"{GROQ_API_URL}/generate", json=payload, headers=headers, timeout=60
        )
    except requests.RequestException as exc:
        st.error(f"Failed to reach the generation API: {exc}")
        return None

    # Check for errors and return the response text if available
    if response.status_code == 200:
        json_response = response.json()
        if "text" in json_response:
            return json_response["text"]
        else:
            st.error("The response from the API did not contain a 'text' key.")
            return None
    else:
        st.error(f"Failed to generate response. Status code: {response.status_code}")
        return None

# Load the counseling conversations dataset (train split) from the Hugging Face
# Hub at import time. Only the "Response" field of each entry is used below.
# NOTE(review): the first run downloads the dataset over the network.
dataset = load_dataset("Amod/mental_health_counseling_conversations")["train"]

# Precompute embeddings for the dataset responses using Groq API
@st.cache_resource
def embed_dataset(_dataset):
    """Embed every "Response" in *_dataset*, preserving index alignment.

    BUG FIX: the original silently skipped entries whose embedding call
    returned None, which shifted every subsequent embedding's position.
    The cosine-similarity indices computed downstream then pointed at the
    WRONG dataset rows. Failed entries now receive a zero-vector placeholder
    (cosine similarity 0, so they are never among the top-k), guaranteeing
    len(result) == len(_dataset) whenever at least one embedding succeeds.

    Args:
        _dataset: Iterable of entries with a "Response" key (leading
            underscore tells st.cache_resource not to hash the argument).

    Returns:
        A list of embedding vectors aligned one-to-one with _dataset,
        or [] if every embedding call failed.
    """
    embeddings = []
    dim = None  # embedding dimensionality, learned from the first success
    for entry in _dataset:
        embedding = retrieve_embedding(entry["Response"])
        if embedding is not None:
            dim = len(embedding)
        embeddings.append(embedding)
    if dim is None:
        # Every call failed — nothing usable; callers must handle [].
        return []
    zero = [0.0] * dim
    return [e if e is not None else zero for e in embeddings]

dataset_embeddings = embed_dataset(dataset)

# Function to retrieve closest responses from the dataset using cosine similarity
def retrieve_response(user_query, dataset, dataset_embeddings, k=5):
    """Return up to *k* dataset responses most similar to *user_query*.

    Embeds the query, ranks every dataset embedding by cosine similarity,
    and returns the best-first responses. Returns [] when the query cannot
    be embedded or when no dataset embeddings are available.

    Args:
        user_query: Free-text query from the user.
        dataset: Indexable collection whose entries have a "Response" key.
        dataset_embeddings: Embedding vectors aligned with *dataset*.
        k: Maximum number of responses to return (default 5).

    Returns:
        list[str] of at most k responses, most similar first.
    """
    query_embedding = retrieve_embedding(user_query)
    if query_embedding is None:
        st.error("Could not retrieve an embedding for the query.")
        return []
    if not dataset_embeddings:
        # FIX: cosine_similarity raises on an empty matrix — bail out early
        # when nothing was embedded (e.g. every upstream API call failed).
        return []

    cos_scores = cosine_similarity([query_embedding], dataset_embeddings)[0]
    # argsort is ascending: take the last k positions, reversed → best first.
    top_indices = np.argsort(cos_scores)[-k:][::-1]

    # int(idx): numpy integer indices are rejected by some datasets versions.
    return [dataset[int(idx)]["Response"] for idx in top_indices]

# Streamlit app UI — retrieval-augmented generation flow:
# user text → top-k similar counseling responses → context → generated reply.
st.title("Emotional Support Buddy")
st.write("Enter your thoughts or concerns, and I'll provide some comforting words.")

# User input (empty string until the user types, which skips the block below)
user_query = st.text_input("How are you feeling today?")

if user_query:
    # Retrieve similar responses from the dataset (may be [] on API failure)
    retrieved_responses = retrieve_response(user_query, dataset, dataset_embeddings)
    
    if retrieved_responses:
        # Join retrieved responses to create a supportive context
        context = " ".join(retrieved_responses)
        
        # Generate a supportive response using FLAN-T5 via Groq API
        # (returns None on failure, handled below)
        supportive_response = generate_response(context)
        
        if supportive_response:
            st.write("Here's some advice or support for you:")
            st.write(supportive_response)
        else:
            st.write("Sorry, I couldn't generate a response at the moment.")
    else:
        st.write("Sorry, I couldn't find any relevant responses.")