import streamlit as st
import os
import requests
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from datasets import load_dataset
# Groq API key setup (replace the placeholder with your own key, or export GROQ_API_KEY before running)
os.environ.setdefault("GROQ_API_KEY", "YOUR_GROQ_API_KEY")
# Groq API endpoint
GROQ_API_URL = "https://api.groq.com/v1/inference"
# Function to perform embedding retrieval using MiniLM via Groq API
def retrieve_embedding(user_query):
    payload = {
        "model": "microsoft/MiniLM-L6-H384-uncased",
        "input_text": user_query
    }
    headers = {
        "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
    }
    # Make the API request
    response = requests.post(f"{GROQ_API_URL}/embedding", json=payload, headers=headers)
    # Check for errors and return the embedding if available
    if response.status_code == 200:
        json_response = response.json()
        if "embedding" in json_response:
            return json_response["embedding"]
        else:
            st.error("The response from the API did not contain an embedding. Please check the API.")
            return None
    else:
        st.error(f"Failed to retrieve embedding. Status code: {response.status_code}")
        return None
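# Usage sketch (assumes the /embedding endpoint above returns a JSON body with an "embedding" list):
#   vec = retrieve_embedding("I have been feeling anxious lately")
#   if vec is not None: ...  # vec is a plain list of floats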
# Function to perform response generation using FLAN-T5 via Groq API
def generate_response(context):
    payload = {
        "model": "google/flan-t5-small",
        "input_text": f"Given the following context, provide a supportive response: {context}"
    }
    headers = {
        "Authorization": f"Bearer {os.getenv('GROQ_API_KEY')}"
    }
    # Make the API request
    response = requests.post(f"{GROQ_API_URL}/generate", json=payload, headers=headers)
    # Check for errors and return the response text if available
    if response.status_code == 200:
        json_response = response.json()
        if "text" in json_response:
            return json_response["text"]
        else:
            st.error("The response from the API did not contain a 'text' key.")
            return None
    else:
        st.error(f"Failed to generate response. Status code: {response.status_code}")
        return None
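# Usage sketch (assumes the /generate endpoint above returns a JSON body with a "text" field):
#   reply = generate_response("You are not alone; many people struggle with this...")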
# Load the counseling conversations dataset
dataset = load_dataset("Amod/mental_health_counseling_conversations")["train"]
# Precompute embeddings for the dataset responses using Groq API
@st.cache_resource
def embed_dataset(_dataset):
    # Keep each response text alongside its embedding so the two lists stay aligned
    # even when an embedding request fails and that entry is skipped.
    embeddings = []
    texts = []
    for entry in _dataset:
        embedding = retrieve_embedding(entry["Response"])
        if embedding is not None:
            embeddings.append(embedding)
            texts.append(entry["Response"])
    return embeddings, texts
dataset_embeddings, dataset_texts = embed_dataset(dataset)
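# Note: this issues one embedding API call per dataset response, so the first run can be slow;
# st.cache_resource keeps the result in memory across Streamlit reruns.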
# Function to retrieve closest responses from the dataset using cosine similarity
def retrieve_response(user_query, dataset_texts, dataset_embeddings, k=5):
    query_embedding = retrieve_embedding(user_query)
    if query_embedding is None:
        st.error("Could not retrieve an embedding for the query.")
        return []
    cos_scores = cosine_similarity([query_embedding], dataset_embeddings)[0]
    # Take the indices of the k highest similarity scores, in descending order
    top_indices = np.argsort(cos_scores)[-k:][::-1]
    retrieved_responses = []
    for idx in top_indices:
        retrieved_responses.append(dataset_texts[idx])
    return retrieved_responses
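# Usage sketch (hypothetical query; k controls how many responses are merged into the context):
#   closest = retrieve_response("I can't sleep at night", dataset_texts, dataset_embeddings, k=3)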
# Streamlit app UI
st.title("Emotional Support Buddy")
st.write("Enter your thoughts or concerns, and I'll provide some comforting words.")
# User input
user_query = st.text_input("How are you feeling today?")
if user_query:
    # Retrieve similar responses from the dataset
    retrieved_responses = retrieve_response(user_query, dataset_texts, dataset_embeddings)
    if retrieved_responses:
        # Join retrieved responses to create a supportive context
        context = " ".join(retrieved_responses)
        # Generate a supportive response using FLAN-T5 via Groq API
        supportive_response = generate_response(context)
        if supportive_response:
            st.write("Here's some advice or support for you:")
            st.write(supportive_response)
        else:
            st.write("Sorry, I couldn't generate a response at the moment.")
    else:
        st.write("Sorry, I couldn't find any relevant responses.")