Spaces:
Sleeping
Sleeping
import os | |
from dotenv import load_dotenv | |
import streamlit as st | |
from streamlit.runtime.scriptrunner import RerunException, StopException | |
from openai import OpenAI | |
from pymongo import MongoClient | |
from pinecone import Pinecone | |
import uuid | |
from datetime import datetime | |
import time | |
from streamlit.runtime.caching import cache_data | |
from streamlit_autorefresh import st_autorefresh | |
# Load environment variables
load_dotenv()
# Configuration
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MONGODB_URI = os.getenv("MONGODB_URI")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_ENVIRONMENT = os.getenv("PINECONE_ENVIRONMENT")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
# memory_id of the single document in global_common_memory that holds the
# shared memory list.
GLOBAL_MEMORY_ID = "global_common_memory_id"
openai_client = OpenAI(api_key=OPENAI_API_KEY)
mongo_client = MongoClient(MONGODB_URI)
db = mongo_client["Wall_Street"]
conversation_history = db["conversation_history"]
global_common_memory = db["global_common_memory"]  # Global common memory collection
# Make sure the global memory document exists. One upsert with $setOnInsert
# does the existence check and the creation in a single server-side operation
# (one round trip, no client-side check-then-insert gap); an existing document
# is left untouched.
global_common_memory.update_one(
    {"memory_id": GLOBAL_MEMORY_ID},
    {"$setOnInsert": {"memory": []}},
    upsert=True,
)
# Initialize Pinecone
pc = Pinecone(api_key=PINECONE_API_KEY)
pinecone_index = pc.Index(PINECONE_INDEX_NAME)
# Streamlit page configuration; set_page_config must be the first Streamlit
# command the script runs.
st.set_page_config(page_title="GPT-Driven Chat System - Tester", page_icon="🔬", layout="wide")
# Placeholder stylesheet injected into the page.
_CUSTOM_CSS = """
<style>
/* Your custom CSS styles */
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Per-browser-session defaults. Each factory runs only when its key is
# missing, so values persist across Streamlit reruns.
_SESSION_STATE_FACTORIES = {
    'chat_history': list,
    'user_type': lambda: None,
    'session_id': lambda: str(uuid.uuid4()),
}
for _state_key, _make_default in _SESSION_STATE_FACTORIES.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# --- Common Memory Functions ---
@st.cache_data(ttl=300)  # Cache for 5 minutes
def get_global_common_memory():
    """Return the list of global common memory items, or [] if none are stored.

    The result is cached by Streamlit for 5 minutes (ttl=300 seconds). The
    writers in this file call ``get_global_common_memory.clear()`` after
    changing the memory; that method only exists because of this cache
    decorator. The cache is per process, so writes made by another process
    can take up to 5 minutes to appear here.
    """
    memory_doc = global_common_memory.find_one({"memory_id": GLOBAL_MEMORY_ID})
    return memory_doc.get('memory', []) if memory_doc else []
def append_to_global_common_memory(content):
    """Append *content* as the last item of the global common memory.

    Creates the memory document if it is missing. It also drops the cached
    get_global_common_memory() result and shows a success message. Finally it
    sets st.session_state['memory_updated'] so that main() reruns the script
    once more. Any failure is reported with st.error rather than raised.
    """
    try:
        # A single upsert covers both cases: on an existing document $push
        # appends to the memory array; on insert MongoDB builds the document
        # from the filter's memory_id and $push creates memory=[content].
        global_common_memory.update_one(
            {"memory_id": GLOBAL_MEMORY_ID},
            {"$push": {"memory": content}},
            upsert=True,
        )
        # Invalidate the cache after updating
        get_global_common_memory.clear()
        st.success("Memory appended successfully!")
        # Signal main() to rerun instead of calling st.rerun() from here.
        st.session_state['memory_updated'] = True
    except Exception as e:
        st.error(f"Failed to append to global common memory: {str(e)}")
def clear_global_common_memory():
    """Reset the global common memory to an empty list.

    Creates the memory document if it does not exist yet. On success the
    cached get_global_common_memory() result is dropped and a confirmation is
    shown; any exception is reported with st.error instead of being raised.
    """
    try:
        selector = {"memory_id": GLOBAL_MEMORY_ID}
        reset = {"$set": {"memory": []}}
        global_common_memory.update_one(selector, reset, upsert=True)
        # Otherwise the cache would keep returning the old items.
        get_global_common_memory.clear()
        st.success("Global common memory cleared successfully!")
    except Exception as e:
        st.error(f"Failed to clear global common memory: {str(e)}")
# --- Relevant Context Retrieval ---
def get_relevant_context(query, top_k=3):
    """Return Pinecone context for *query* as a single space-joined string.

    Embeds *query* with OpenAI ``text-embedding-3-large`` and asks the Pinecone
    index for the ``top_k`` nearest matches (with metadata). It then joins the
    ``text`` metadata of those matches with single spaces.

    Args:
        query: The user's message.
        top_k: Number of nearest matches to request from Pinecone.

    Returns:
        The joined context. Returns ``""`` when the query is empty or blank,
        when no match carries text, or when embedding/querying fails (the
        error is printed).
    """
    # Nothing meaningful can be retrieved for blank input, so skip the API calls.
    if not query or not query.strip():
        return ""
    try:
        query_embedding = openai_client.embeddings.create(
            model="text-embedding-3-large",
            input=query,
        ).data[0].embedding
        results = pinecone_index.query(vector=query_embedding, top_k=top_k, include_metadata=True)
        contexts = []
        for item in results['matches']:
            # Skip matches without text metadata instead of raising KeyError,
            # which would throw away the context from every other match.
            metadata = item.get('metadata') or {}
            text = metadata.get('text')
            if text:
                contexts.append(str(text))
        return " ".join(contexts)
    except Exception as e:
        print(f"Error retrieving context: {str(e)}")
        return ""
# --- GPT Response Function ---
def get_gpt_response(prompt, context="", system_message=None):
    """Generate an assistant reply with ``gpt-4o-mini``.

    The system prompt is built from a fixed instruction, optionally followed
    by trainer instructions (*system_message*) and then by the global common
    memory items, one per line. The user turn is
    ``"Context: <context>\\n\\nUser query: <prompt>"``.

    Args:
        prompt: The user's message.
        context: Retrieved background text, e.g. from get_relevant_context().
        system_message: Optional extra trainer instructions.

    Returns:
        The reply text with surrounding whitespace stripped. Returns None when
        anything fails or the model returned no text content; the error is
        shown with st.error in both cases.
    """
    try:
        common_memory = get_global_common_memory()
        system_msg = (
            "You are a helpful assistant. Use the following context and global common memory "
            "to inform your responses, but don't mention them explicitly unless directly relevant to the user's question."
        )
        if system_message:
            system_msg += f"\n\nTrainer Instructions:\n{system_message}"
        if common_memory:
            # str() keeps a non-string item (stored by some other writer) from
            # making the join raise TypeError and losing the whole reply.
            memory_str = "\n".join(str(item) for item in common_memory)
            system_msg += f"\n\nGlobal Common Memory:\n{memory_str}"
        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": f"Context: {context}\n\nUser query: {prompt}"},
        ]
        completion = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
        )
        content = completion.choices[0].message.content
        if content is None:
            # The API can return a message without text (e.g. a refusal or a
            # tool call); report it clearly instead of failing on .strip().
            st.error("Error generating response: the model returned no text content.")
            return None
        return content.strip()
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return None
# --- Send User Message ---
def _is_takeover_active(session_id):
    """Return True when an admin takeover is active for *session_id*.

    Takeover state is stored in db.takeover_status by activate_takeover() and
    deactivate_takeover(). The st.session_state['admin_takeover_active'] flag
    is also honoured when something sets it.
    """
    if st.session_state.get('admin_takeover_active'):
        return True
    takeover_doc = db.takeover_status.find_one({"session_id": session_id})
    return bool(takeover_doc and takeover_doc.get("active", False))


def send_message(message):
    """Store a user message and, unless an admin has taken over, a GPT reply.

    The user message (status "approved") is pushed onto this session's
    conversation_history document, which is created on first use. It is also
    appended to st.session_state['chat_history'].

    If no takeover is active, Pinecone context is retrieved and a GPT reply is
    generated. The reply is stored with status "pending" for trainer approval
    and appended to the local history as well. During a takeover no GPT reply
    is generated; replies come from the admin via send_admin_message().
    """
    session_id = st.session_state['session_id']
    now = datetime.utcnow()
    user_message = {
        "role": "user",
        "content": message,
        "timestamp": now,
        "status": "approved",  # User messages are always approved
    }
    conversation_history.update_one(
        {"session_id": session_id},
        {
            "$push": {"messages": user_message},
            "$set": {"last_updated": now},
            "$setOnInsert": {"created_at": now},
        },
        upsert=True,
    )
    st.session_state['chat_history'].append(user_message)

    if _is_takeover_active(session_id):
        return

    # Retrieval is only needed for GPT, so it is skipped during a takeover.
    context = get_relevant_context(message)
    # gpt_response is None if generation failed; it is still queued as a
    # pending reply, where the trainer's review queue offers Regenerate.
    gpt_response = get_gpt_response(message, context)
    reply_time = datetime.utcnow()
    assistant_message = {
        "role": "assistant",
        "content": gpt_response,
        "timestamp": reply_time,
        "status": "pending",  # Held until a trainer approves it
    }
    conversation_history.update_one(
        {"session_id": session_id},
        {
            "$push": {"messages": assistant_message},
            "$set": {"last_updated": reply_time},
        },
    )
    st.session_state['chat_history'].append(assistant_message)
# --- Send Admin Message ---
def send_admin_message(message, *legacy_args, session_id=None):
    """Store an admin message (status "approved") in a conversation.

    Accepted call forms:

    * ``send_admin_message(text)``: writes to the current Streamlit session.
    * ``send_admin_message(text, session_id=sid)``: writes to *sid*.
    * ``send_admin_message(sid, text)``: the positional form trainer_page() uses.

    Args:
        message: The admin's text. In the two-positional form it is the
            target session id, and the text is the second argument.
        session_id: Conversation to write to. Defaults to
            st.session_state['session_id'].

    The message is appended to st.session_state['chat_history'] only when it
    targets the current session. This keeps a trainer writing into another
    user's conversation from adding it to their own local history.

    Raises:
        TypeError: for any other combination of arguments.
    """
    if legacy_args:
        if len(legacy_args) != 1 or session_id is not None:
            raise TypeError(
                "send_admin_message() takes (message), (message, session_id=...) "
                "or (session_id, message)"
            )
        session_id, message = message, legacy_args[0]

    own_session_id = st.session_state['session_id']
    target_session_id = own_session_id if session_id is None else session_id
    now = datetime.utcnow()
    admin_message = {
        "role": "admin",
        "content": message,
        "timestamp": now,
        "status": "approved",
    }
    conversation_history.update_one(
        {"session_id": target_session_id},
        {
            "$push": {"messages": admin_message},
            "$set": {"last_updated": now},
        },
    )
    if target_session_id == own_session_id:
        st.session_state['chat_history'].append(admin_message)
# --- Takeover Functions ---
def activate_takeover(session_id):
    """Turn takeover mode on for *session_id*.

    Upserts the session's takeover_status document with ``active=True`` and
    ``activated_at`` set to the current UTC time. The outcome is reported with
    st.success or st.error; exceptions are not re-raised.
    """
    try:
        selector = {"session_id": session_id}
        changes = {"$set": {"active": True, "activated_at": datetime.utcnow()}}
        db.takeover_status.update_one(selector, changes, upsert=True)
        short_id = session_id[:8]
        st.success(f"Takeover activated for session {short_id}...")
    except Exception as e:
        st.error(f"Failed to activate takeover: {str(e)}")
def deactivate_takeover(session_id):
    """Turn takeover mode off for *session_id*.

    Sets ``active=False`` on the session's existing takeover_status document
    (no upsert). The outcome is reported with st.success or st.error;
    exceptions are not re-raised.
    """
    try:
        selector = {"session_id": session_id}
        changes = {"$set": {"active": False}}
        db.takeover_status.update_one(selector, changes)
        short_id = session_id[:8]
        st.success(f"Takeover deactivated for session {short_id}...")
    except Exception as e:
        st.error(f"Failed to deactivate takeover: {str(e)}")
def handle_admin_takeover(session_id):
    """Render takeover controls for *session_id*, plus a message box while active.

    Shows the current takeover state with a button that toggles it and then
    reruns. While the takeover is active, a form lets the admin send a
    message; the text area is cleared after a successful submit.
    """
    st.subheader("Admin Takeover")
    takeover_doc = db.takeover_status.find_one({"session_id": session_id})
    is_active = takeover_doc.get("active", False) if takeover_doc else False

    if not is_active:
        st.warning("Takeover is not active for this session.")
        if st.button("Activate Takeover"):
            activate_takeover(session_id)  # reports its own success/error
            st.rerun()
        return

    st.info("Takeover is currently active for this session.")
    if st.button("Deactivate Takeover"):
        deactivate_takeover(session_id)  # reports its own success/error
        st.rerun()

    # Writing st.session_state["admin_message"] after the text area exists
    # raises StreamlitAPIException, so the form's clear_on_submit does the
    # clearing instead.
    with st.form(key="admin_message_form", clear_on_submit=True):
        admin_message = st.text_area("Send Message to User", key="admin_message")
        submitted = st.form_submit_button("Send Admin Message")
    if submitted:
        text = (admin_message or "").strip()
        if text:
            # NOTE(review): no session_id is passed, so the message goes to the
            # conversation of st.session_state['session_id'], which is not
            # necessarily `session_id`; confirm this is intended.
            send_admin_message(text)
            st.success("Admin message sent successfully!")
        else:
            st.warning("Please enter a message to send.")
# --- View Full Chat (User Perspective) ---
# Stored message role (lower-cased) -> (st.chat_message avatar name, label
# shown above the message). Messages with other roles are not rendered.
_CHAT_ROLE_DISPLAY = {
    "user": ("user", "User"),
    "assistant": ("assistant", "Assistant"),
    "admin": ("human", "Admin"),
}


def _find_chat(session_id):
    """Return the chat document for *session_id*, or None if none exists.

    db.chat_history is checked first. Sessions opened from the dashboard's
    "Active Chats" list are stored in conversation_history, so that
    collection is the fallback.
    """
    chat = db.chat_history.find_one({"session_id": session_id})
    if chat is None:
        chat = conversation_history.find_one({"session_id": session_id})
    return chat


def _render_back_to_history_button():
    """Show a centred button that leaves the full-chat view and reruns."""
    _, middle, _ = st.columns([1, 1, 1])
    with middle:
        if st.button("Back to Chat History", use_container_width=True):
            # main() opens this view for either key ('full_chat_view' from the
            # Chat History tab, 'selected_chat' from Active Chats). Clearing
            # only one of them would reopen the view on the next run.
            st.session_state.pop('full_chat_view', None)
            st.session_state.pop('selected_chat', None)
            st.rerun()


def view_full_chat(session_id):
    """Render one conversation in full, plus a form to add global memory.

    Args:
        session_id: The conversation's session id.

    Each user/assistant/admin message is shown with its timestamp. Below the
    messages, a text area adds an item to the global common memory, and a
    button returns to the dashboard. The back button is also shown when the
    chat cannot be found, so the view can always be left.
    """
    st.title(f"Full Chat View - Session: {session_id[:8]}...")
    chat = _find_chat(session_id)
    if not chat:
        st.error("Chat not found.")
        _render_back_to_history_button()
        return

    col1, col2 = st.columns([2, 1])
    with col1:
        st.subheader(f"Session ID: {session_id}")
    with col2:
        st.write(f"Last Updated: {chat.get('last_updated', 'N/A')}")
    st.markdown("---")

    for message in chat.get('messages', []):
        display = _CHAT_ROLE_DISPLAY.get(str(message.get('role', '')).lower())
        if display is None:
            continue
        avatar, label = display
        with st.chat_message(avatar):
            st.markdown(f"**{label}** - {message.get('timestamp', 'N/A')}")
            st.markdown(message.get('content') or "")
    st.markdown("---")

    # Add text box to append to global memory
    st.subheader("Add to Global Memory")
    new_memory = st.text_area("Enter new memory item", key=f"new_memory_input_{session_id}")
    flag_key = f'memory_added_{session_id}'
    if st.button("Add Memory", key=f"add_memory_button_{session_id}"):
        if new_memory.strip():
            append_to_global_common_memory(new_memory.strip())
            # The flag survives the rerun, so the confirmation below is shown
            # on the next run (a message rendered now would be wiped by it).
            st.session_state[flag_key] = True
            st.rerun()
        else:
            st.warning("Please enter a valid memory item.")
    # Show the confirmation once, then clear the flag.
    if st.session_state.pop(flag_key, False):
        st.success("Memory item added successfully!")

    st.markdown("---")
    _render_back_to_history_button()
# clear_global_common_memory() is defined once, with the other global-memory
# helpers under "Common Memory Functions". A second definition here would
# silently replace that one.
def display_chat_history():
    """Render every chat in db.chat_history, newest first, as an expander.

    Each expander previews the last message (its first 100 characters). Its
    "Show Full Chat" button stores the session id in
    st.session_state['full_chat_view'] and reruns, which main() turns into
    the full-chat view.
    """
    st.subheader("All Chat History")
    # NOTE(review): this tab reads db.chat_history, whereas this file writes
    # messages to conversation_history (which "Active Chats" lists); confirm
    # which collection is intended here.
    all_chats = list(db.chat_history.find().sort("last_updated", -1))
    if not all_chats:
        st.info("No chat history found.")
        return

    preview_len = 100
    for chat in all_chats:
        session_id = chat['session_id']
        last_updated = chat.get('last_updated', 'N/A')
        with st.expander(f"Session: {session_id[:8]}... - Last Updated: {last_updated}"):
            messages = chat.get('messages')
            if messages:
                last_message = messages[-1]
                role = str(last_message.get('role', 'unknown')).capitalize()
                # Content can be None (e.g. a failed GPT reply), which cannot be sliced.
                content = str(last_message.get('content') or "")
                ellipsis = "..." if len(content) > preview_len else ""
                st.markdown(f"Last message ({role}):")
                st.markdown(f"> {content[:preview_len]}{ellipsis}")
            # Key on the document's _id rather than its list position: the list
            # is re-sorted on every rerun, so a positional key could open a
            # different chat than the one clicked.
            if st.button("Show Full Chat", key=f"show_full_chat_{chat['_id']}"):
                st.session_state['full_chat_view'] = session_id
                st.rerun()
def trainer_intervention_tab():
    """Render the Intervention tab: a heading, then the pending-response review UI."""
    heading = "Trainer Intervention"
    st.subheader(heading)
    handle_admin_intervention()
def handle_admin_intervention():
    """Review queue for assistant replies awaiting trainer approval.

    Lists each conversation's pending assistant messages with Approve, Modify,
    Regenerate and takeover buttons. Modify and Regenerate only record their
    target in st.session_state ('modifying' / 'regenerating', as a
    ``(session_id, message_index)`` tuple) and rerun. The matching editor or
    form is then drawn below the list on later runs.
    """
    st.subheader("Review Pending Responses")
    # $elemMatch requires role and status to match on the same array element,
    # so only conversations with a pending *assistant* message are listed.
    pending_responses = conversation_history.find(
        {"messages": {"$elemMatch": {"role": "assistant", "status": "pending"}}}
    )
    for conversation in pending_responses:
        _render_pending_conversation(conversation)
    if 'regenerating' in st.session_state:
        _render_regenerate_form()
    if 'modifying' in st.session_state:
        _render_modify_editor()


def _render_pending_conversation(conversation):
    """Show one conversation's pending assistant replies with their action buttons."""
    session_id = conversation['session_id']
    messages = conversation['messages']
    st.write(f"Session ID: {session_id[:8]}...")
    # Takeover state is per conversation, so one lookup serves every message.
    takeover_doc = db.takeover_status.find_one({"session_id": session_id})
    takeover_active = takeover_doc.get("active", False) if takeover_doc else False

    for i, message in enumerate(messages):
        if message.get('role') != 'assistant' or message.get('status') != 'pending':
            continue
        # The preceding message is shown as the prompt (send_message stores the
        # user turn immediately before its reply).
        user_message = messages[i - 1]['content'] if i > 0 else "N/A"
        st.write(f"**User:** {user_message}")
        st.write(f"**GPT:** {message['content']}")
        key_suffix = f"{session_id}_{i}"
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            if st.button("Approve", key=f"approve_{key_suffix}"):
                if approve_response(session_id, i):
                    st.success("Response approved")
                    time.sleep(0.5)  # let the confirmation show before rerunning
                    st.rerun()
        with col2:
            if st.button("Modify", key=f"modify_{key_suffix}"):
                st.session_state['modifying'] = (session_id, i)
                st.rerun()
        with col3:
            if st.button("Regenerate", key=f"regenerate_{key_suffix}"):
                st.session_state['regenerating'] = (session_id, i)
                st.rerun()
        with col4:
            if takeover_active:
                if st.button("Deactivate Takeover", key=f"deactivate_takeover_{key_suffix}"):
                    deactivate_takeover(session_id)  # reports its own success/error
                    st.rerun()
            elif st.button("Activate Takeover", key=f"activate_takeover_{key_suffix}"):
                activate_takeover(session_id)  # reports its own success/error
                st.rerun()
        st.divider()


def _render_regenerate_form():
    """Ask for extra instructions and regenerate the reply chosen with "Regenerate"."""
    session_id, message_index = st.session_state['regenerating']
    with st.form(key="regenerate_form"):
        operator_input = st.text_input("Enter additional instructions for regeneration:")
        submit_button = st.form_submit_button("Submit")
    if submit_button:
        del st.session_state['regenerating']
        regenerate_response(session_id, message_index, operator_input)
        st.success("Response regenerated with operator input.")
        st.rerun()


def _render_modify_editor():
    """Edit the reply chosen with "Modify" and save it as approved."""
    session_id, message_index = st.session_state['modifying']
    conversation = conversation_history.find_one({"session_id": session_id})
    messages = conversation.get('messages', []) if conversation else []
    if not 0 <= message_index < len(messages):
        # The chat was deleted (or shrank) after Modify was clicked. Drop the
        # stale selection; otherwise this lookup would fail on every rerun.
        del st.session_state['modifying']
        st.warning("The response selected for modification no longer exists.")
        return
    message = messages[message_index]
    modified_content = st.text_area("Modify the response:", value=message.get('content') or "")
    if st.button("Save Modification"):
        save_modified_response(session_id, message_index, modified_content)
        st.success("Response modified and approved")
        del st.session_state['modifying']
        st.rerun()
def approve_response(session_id, message_index):
    """Set the status of message *message_index* in a conversation to "approved".

    Returns:
        True if MongoDB reports that the document was modified. False if
        nothing changed (e.g. unknown session, or the status was already
        "approved"), or if the update raised; in that case the error is
        shown with st.error.
    """
    try:
        status_path = f"messages.{message_index}.status"
        outcome = conversation_history.update_one(
            {"session_id": session_id},
            {"$set": {status_path: "approved"}},
        )
        return outcome.modified_count > 0
    except Exception as e:
        st.error(f"Failed to approve response: {str(e)}")
        return False
def save_modified_response(session_id, message_index, modified_content):
    """Overwrite a stored reply with trainer-edited text and mark it approved.

    Errors are shown with st.error rather than raised; nothing is returned.
    """
    try:
        path = f"messages.{message_index}"
        new_fields = {
            f"{path}.content": modified_content,
            f"{path}.status": "approved",
        }
        conversation_history.update_one({"session_id": session_id}, {"$set": new_fields})
    except Exception as e:
        st.error(f"Failed to save modified response: {str(e)}")
def regenerate_response(session_id, message_index, operator_input):
    """Regenerate the assistant reply at *message_index* of a conversation.

    The prompt is the message just before it. Fresh Pinecone context is
    retrieved, as send_message() does, and *operator_input* is passed to
    get_gpt_response() as trainer instructions. The new reply replaces the
    stored content, with status reset to "pending" so it returns to the
    review queue.

    Errors are shown with st.error. If the conversation or message no longer
    exists, or generation fails, the stored reply is left unchanged.
    """
    try:
        conversation = conversation_history.find_one({"session_id": session_id})
        messages = conversation.get('messages', []) if conversation else []
        if not 0 <= message_index < len(messages):
            st.error("Failed to regenerate response: the selected message no longer exists.")
            return
        user_message = messages[message_index - 1]['content'] if message_index > 0 else ""
        context = get_relevant_context(user_message)
        new_response = get_gpt_response(user_message, context, system_message=operator_input)
        if new_response is None:
            # get_gpt_response() already showed the error; keep the old reply
            # rather than overwriting it with None.
            return
        conversation_history.update_one(
            {"session_id": session_id},
            {
                "$set": {
                    f"messages.{message_index}.content": new_response,
                    f"messages.{message_index}.status": "pending",
                }
            },
        )
    except Exception as e:
        st.error(f"Failed to regenerate response: {str(e)}")
def trainer_page():
    """Render the trainer dashboard: Current Status, Chat History and Intervention tabs.

    The page reruns itself every 10 seconds via st_autorefresh so that new
    chats and pending replies appear without a manual refresh.
    """
    st.title("Trainer Dashboard")
    # Add auto-refresh every 10 seconds (10000 milliseconds)
    st_autorefresh(interval=10000, limit=None, key="trainer_autorefresh")
    tab1, tab2, tab3 = st.tabs(["Current Status", "Chat History", "Intervention"])
    with tab1:
        _render_global_memory_section()
        _render_active_chats()
        # Manual refresh button
        if st.button("Refresh", key="refresh_button"):
            st.rerun()
    with tab2:
        display_chat_history()
    with tab3:
        trainer_intervention_tab()


def _render_global_memory_section():
    """List the global memory items and offer a button that clears them."""
    st.subheader("Current Global Memory")
    global_memory = get_global_common_memory()
    if global_memory:
        for idx, item in enumerate(global_memory, 1):
            st.text(f"{idx}. {item}")
    else:
        st.info("No global memory items found.")
    if st.button("Clear Global Memory", key="clear_global_memory"):
        clear_global_common_memory()  # reports its own success/error
        time.sleep(1)  # let the message show before rerunning
        st.rerun()


def _render_active_chats():
    """Show the five most recently updated conversations."""
    st.subheader("Active Chats")
    chats = list(conversation_history.find().sort("last_updated", -1).limit(5))
    for chat in chats:
        _render_active_chat(chat)


def _render_active_chat(chat):
    """Show one chat's last five messages and its per-chat action buttons."""
    session_id = chat['session_id']
    # Widget keys use the document's _id, not its list position. The list is
    # re-sorted by last_updated on every rerun, so a positional key could
    # apply a click (e.g. Delete Chat) to a different conversation.
    wkey = str(chat['_id'])
    with st.expander(f"Session: {session_id[:8]}... - Last Updated: {chat.get('last_updated', 'N/A')}"):
        for message in chat.get('messages', [])[-5:]:
            role = str(message.get('role', '')).capitalize()
            st.markdown(f"**{role}:** {message.get('content')}")

        takeover_doc = db.takeover_status.find_one({"session_id": session_id})
        takeover_active = takeover_doc.get("active", False) if takeover_doc else False
        col1, col2, col3, col4 = st.columns(4)
        with col1:
            if st.button("View Full Chat", key=f"view_chat_{wkey}"):
                st.session_state['selected_chat'] = session_id
                st.rerun()
        with col2:
            if takeover_active:
                if st.button("Deactivate Takeover", key=f"deactivate_takeover_{wkey}"):
                    deactivate_takeover(session_id)  # reports its own success/error
                    st.rerun()
            elif st.button("Activate Takeover", key=f"activate_takeover_{wkey}"):
                activate_takeover(session_id)  # reports its own success/error
                st.rerun()
        with col3:
            if st.button("Delete Chat", key=f"delete_chat_{wkey}"):
                delete_chat(session_id)
                st.success(f"Chat {session_id[:8]}... deleted.")
                st.rerun()
        with col4:
            if takeover_active:
                # clear_on_submit empties the input after sending; the widget's
                # session_state value cannot be reset once the widget exists.
                with st.form(key=f"takeover_form_{wkey}", clear_on_submit=True):
                    text = st.text_input("Send message", key=f"takeover_message_{wkey}")
                    sent = st.form_submit_button("Send")
                if sent:
                    text = (text or "").strip()
                    if text:
                        send_admin_message(session_id, text)
                        st.success("Message sent.")
                        st.rerun()
                    else:
                        st.warning("Please enter a message to send.")
def delete_chat(session_id):
    """Delete the conversation_history document for *session_id*.

    Shows an error if no document was deleted or if the delete raised;
    nothing is returned.
    """
    try:
        outcome = conversation_history.delete_one({"session_id": session_id})
        nothing_deleted = not outcome.deleted_count
        if nothing_deleted:
            st.error("Failed to delete chat. Please try again.")
    except Exception as e:
        st.error(f"Error deleting chat: {str(e)}")
# --- Main Function ---
def main():
    """Choose which trainer screen to render on this run.

    Precedence: a pending memory-update rerun first, then the full-chat view
    (opened via 'full_chat_view' or 'selected_chat', in that order), and
    otherwise the dashboard. Streamlit's rerun/stop exceptions are re-raised;
    any other exception is shown with st.error.
    """
    try:
        if 'memory_updated' in st.session_state:
            # Consume the flag set by append_to_global_common_memory() and start
            # a fresh run.
            del st.session_state['memory_updated']
            st.rerun()
        for view_key in ('full_chat_view', 'selected_chat'):
            if view_key in st.session_state:
                view_full_chat(st.session_state[view_key])
                break
        else:
            trainer_page()
    except (RerunException, StopException):
        # st.rerun() / st.stop() work by raising these; Streamlit must see them.
        raise
    except Exception as e:
        st.error(f"An unexpected error occurred: {str(e)}")


if __name__ == "__main__":
    main()