Spaces:
Sleeping
Sleeping
import os | |
from dotenv import load_dotenv | |
import streamlit as st | |
from streamlit.runtime.scriptrunner import RerunException, StopException, RerunData | |
from openai import OpenAI | |
from pymongo import MongoClient | |
from datetime import datetime, timedelta | |
import time | |
from streamlit_autorefresh import st_autorefresh | |
from streamlit.runtime.caching import cache_data | |
import logging | |
# Set up logging
logging.basicConfig(level=logging.INFO)

# Load environment variables (e.g. from a local .env file)
load_dotenv()

# Configuration
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MONGODB_URI = os.getenv("MONGODB_URI")
if not MONGODB_URI:
    # MongoClient(None) falls back to its default host instead of failing,
    # which hides a missing configuration; make that visible in the logs.
    logging.warning("MONGODB_URI is not set; MongoClient will use its default host.")

# Initialize clients
openai_client = OpenAI(api_key=OPENAI_API_KEY)
mongo_client = MongoClient(MONGODB_URI)
db = mongo_client["Wall_Street"]
conversation_history = db["conversation_history"]
trainer_feedback = db["trainer_feedback"]
trainer_instructions = db["trainer_instructions"]
global_common_memory = db["global_common_memory"]  # Memory shared by all sessions

# Define a unique identifier for global memory
GLOBAL_MEMORY_ID = "global_common_memory_id"

# Set up Streamlit page configuration (called before any other Streamlit command)
st.set_page_config(page_title="GPT-Driven Chat System - Operator", page_icon="🛠️", layout="wide")

# Custom CSS to improve the UI
st.markdown("""
<style>
/* Your custom CSS styles */
</style>
""", unsafe_allow_html=True)
# --- Common Memory Functions ---
@st.cache_data(ttl=300)  # Cache for 5 minutes
def get_global_common_memory():
    """Retrieve the global common memory.

    The result is cached for five minutes. The decorator also provides the
    ``get_global_common_memory.clear()`` method that the memory writers call
    to invalidate the cache after a change.

    Returns:
        list: The ``memory`` array of the global memory document, or an empty
        list when that document does not exist or has no ``memory`` field.
    """
    memory_doc = global_common_memory.find_one({"memory_id": GLOBAL_MEMORY_ID})
    return memory_doc.get("memory", []) if memory_doc else []
def append_to_global_common_memory(content):
    """Append ``content`` to the global common memory, then rerun the script.

    A single ``$addToSet`` upsert creates the memory document on first use and
    adds ``content`` only if it is not already present, so the write is atomic.

    Args:
        content: The memory item to store. ``None`` and empty or
            whitespace-only strings are rejected with a warning.
    """
    if content is None or (isinstance(content, str) and not content.strip()):
        st.warning("Cannot append empty content to global common memory.")
        return
    try:
        result = global_common_memory.update_one(
            {"memory_id": GLOBAL_MEMORY_ID},
            {"$addToSet": {"memory": content}},
            upsert=True,
        )
        # Invalidate the cache after updating
        get_global_common_memory.clear()
    except Exception as e:
        st.error(f"Failed to append to global common memory: {str(e)}")
        return

    # When the document is created by the upsert, modified_count is 0 but
    # upserted_id is set, so both cases count as a successful append.
    if result.modified_count > 0 or result.upserted_id is not None:
        st.success("Memory appended successfully!")
    else:
        st.info("This memory item already exists or no changes were made.")
    # Rerun so the page re-reads the updated memory. This is kept outside the
    # try block so the rerun control flow can never be reported as an error.
    st.rerun()
def clear_global_common_memory():
    """Reset the global common memory to an empty list.

    Creates the memory document if it is missing, invalidates the cached copy
    from ``get_global_common_memory``, and reports the outcome in the UI.
    """
    memory_filter = {"memory_id": GLOBAL_MEMORY_ID}
    reset_to_empty = {"$set": {"memory": []}}
    try:
        global_common_memory.update_one(memory_filter, reset_to_empty, upsert=True)
        get_global_common_memory.clear()  # drop the now-stale cached memory
        st.success("Global common memory cleared successfully!")
    except Exception as exc:
        st.error(f"Failed to clear global common memory: {str(exc)}")
# --- Takeover Functions ---
def activate_takeover(session_id):
    """Turn on operator takeover for ``session_id``.

    Upserts the session's ``takeover_status`` document with ``active=True``
    and the activation time from ``datetime.utcnow()``, then reports the
    outcome in the UI.
    """
    status_update = {"$set": {"active": True, "activated_at": datetime.utcnow()}}
    try:
        db.takeover_status.update_one({"session_id": session_id}, status_update, upsert=True)
        st.success(f"Takeover activated for session {session_id[:8]}...")
    except Exception as exc:
        st.error(f"Failed to activate takeover: {str(exc)}")
def deactivate_takeover(session_id):
    """Turn off operator takeover for ``session_id``.

    Sets ``active=False`` on the session's ``takeover_status`` document. There
    is no upsert, so nothing is created when no such document exists. The
    outcome is reported in the UI.
    """
    selector = {"session_id": session_id}
    try:
        db.takeover_status.update_one(selector, {"$set": {"active": False}})
        st.success(f"Takeover deactivated for session {session_id[:8]}...")
    except Exception as exc:
        st.error(f"Failed to deactivate takeover: {str(exc)}")
def send_admin_message(session_id, message):
    """Append an operator ("admin") message to a session's conversation.

    The message is stored with status ``"approved"``. The conversation's
    ``last_updated`` is set to the same timestamp as the message.

    Args:
        session_id: Session whose conversation document receives the message.
        message: Text to send. Empty or whitespace-only text is rejected.

    Returns:
        bool: True if the message was stored, False otherwise. The outcome is
        also reported in the UI, so callers may ignore the return value.
    """
    if not message or (isinstance(message, str) and not message.strip()):
        st.warning("Cannot send an empty admin message.")
        return False

    now = datetime.utcnow()
    admin_message = {
        "role": "admin",
        "content": message,
        "timestamp": now,
        "status": "approved",
    }
    try:
        # Intentionally no upsert: only existing conversations receive the
        # message, so a deleted session is not silently recreated.
        result = conversation_history.update_one(
            {"session_id": session_id},
            {
                "$push": {"messages": admin_message},
                "$set": {"last_updated": now},
            },
        )
    except Exception as e:
        st.error(f"Failed to send admin message: {str(e)}")
        return False

    if result.modified_count > 0:
        st.success("Admin message sent successfully!")
        return True
    if result.matched_count == 0:
        st.error("Failed to send admin message: conversation not found.")
    else:
        st.error("Failed to send admin message.")
    return False
# --- Admin Dashboard Functions ---
def handle_admin_intervention():
    """Render the operator review queue for pending GPT responses.

    Shows the global common memory once. Then, for each conversation containing
    a pending assistant message, it shows Approve / Modify / Regenerate
    controls. Finally it shows the regenerate form or the modify editor when
    one has been requested through ``st.session_state``.
    """
    st.subheader("Admin Intervention")

    # The memory is identical for every conversation, so render it only once.
    _render_global_common_memory()

    st.subheader("Review Pending Responses")
    # $elemMatch requires ONE message to be both an assistant message and
    # pending. The dotted-path form matched the two conditions on any
    # (possibly different) messages.
    pending_responses = conversation_history.find(
        {"messages": {"$elemMatch": {"role": "assistant", "status": "pending"}}}
    )
    for conversation in pending_responses:
        _render_pending_conversation(conversation)

    if "regenerating" in st.session_state:
        _render_regenerate_form()

    if "modifying" in st.session_state:
        _render_modify_editor()


def _render_global_common_memory():
    """Show the numbered global common memory items, or a notice when empty."""
    st.subheader("Global Common Memory")
    common_memory = get_global_common_memory()
    if common_memory:
        for idx, item in enumerate(common_memory, 1):
            st.text(f"{idx}. {item}")
    else:
        st.info("Global common memory is currently empty.")


def _render_pending_conversation(conversation):
    """Show each pending assistant message of one conversation with its review buttons."""
    session_id = conversation["session_id"]
    messages = conversation.get("messages", [])
    st.write(f"Session ID: {session_id[:8]}...")
    for i, message in enumerate(messages):
        if message.get("role") != "assistant" or message.get("status") != "pending":
            continue
        # The message right before the assistant reply is shown as the user turn.
        user_message = messages[i - 1].get("content", "N/A") if i > 0 else "N/A"
        st.write(f"**User:** {user_message}")
        st.write(f"**GPT:** {message.get('content')}")
        col1, col2, col3 = st.columns(3)
        with col1:
            if st.button("Approve", key=f"approve_{session_id}_{i}"):
                if approve_response(session_id, i):
                    st.success("Response approved")
                    time.sleep(0.5)  # Short delay to ensure the success message is visible
                    st.rerun()
        with col2:
            if st.button("Modify", key=f"modify_{session_id}_{i}"):
                st.session_state["modifying"] = (session_id, i)
                st.rerun()
        with col3:
            if st.button("Regenerate", key=f"regenerate_{session_id}_{i}"):
                st.session_state["regenerating"] = (session_id, i)
                st.rerun()
    st.divider()


def _render_regenerate_form():
    """Collect extra operator instructions and regenerate the selected response."""
    try:
        session_id, message_index = st.session_state["regenerating"]
    except (TypeError, ValueError):
        # Clear the malformed request so the error is not repeated on every rerun.
        del st.session_state["regenerating"]
        st.error("Invalid regenerating state. Please try again.")
        return
    with st.form(key="regenerate_form"):
        operator_input = st.text_input("Enter additional instructions for regeneration:")
        submit_button = st.form_submit_button("Submit")
    if submit_button:
        del st.session_state["regenerating"]  # Remove the key after submission
        regenerate_response(session_id, message_index, operator_input)
        st.success("Response regenerated with operator input.")
        st.rerun()


def _render_modify_editor():
    """Let the operator edit the selected response and save it as approved."""
    session_id, message_index = st.session_state["modifying"]
    conversation = conversation_history.find_one({"session_id": session_id})
    messages = conversation.get("messages", []) if conversation else []
    if not 0 <= message_index < len(messages):
        # The chat was deleted or changed after "Modify" was clicked. Drop the
        # stale request instead of raising on every rerun of the dashboard.
        del st.session_state["modifying"]
        st.warning("The response being modified is no longer available.")
        return
    message = messages[message_index]
    modified_content = st.text_area("Modify the response:", value=message.get("content") or "")
    if st.button("Save Modification"):
        save_modified_response(session_id, message_index, modified_content)
        st.success("Response modified and approved")
        del st.session_state["modifying"]
        st.rerun()
def approve_response(session_id, message_index):
    """Mark message ``message_index`` of ``session_id`` as approved.

    Returns:
        bool: True if the stored document was modified. False if nothing
        changed or the update raised; in that case the error is shown in
        the UI.
    """
    status_field = f"messages.{message_index}.status"
    try:
        outcome = conversation_history.update_one(
            {"session_id": session_id},
            {"$set": {status_field: "approved"}},
        )
        return outcome.modified_count > 0
    except Exception as exc:
        st.error(f"Failed to approve response: {str(exc)}")
        return False
def save_modified_response(session_id, message_index, modified_content):
    """Replace a message's content with the operator's edit and mark it approved.

    Errors are shown in the UI rather than raised.
    """
    field_prefix = f"messages.{message_index}"
    changes = {
        f"{field_prefix}.content": modified_content,
        f"{field_prefix}.status": "approved",
    }
    try:
        conversation_history.update_one({"session_id": session_id}, {"$set": changes})
    except Exception as exc:
        st.error(f"Failed to save modified response: {str(exc)}")
def regenerate_response(session_id, message_index, operator_input):
    """Regenerate the message at ``message_index`` with GPT and store it.

    The preceding message's content (or ``""`` at index 0) is sent as the
    prompt, and ``operator_input`` is passed as extra system instructions. The
    new text is stored as ``"pending"`` when ``get_gpt_response`` flags it as
    uncertain, otherwise as ``"approved"``.

    Returns:
        bool: True if the message was updated, False otherwise. Errors are
        shown in the UI, so callers may ignore the return value.
    """
    try:
        conversation = conversation_history.find_one({"session_id": session_id})
        if conversation is None:
            st.error("Failed to regenerate response: conversation not found.")
            return False
        messages = conversation.get("messages", [])
        # Without this check, "$set" on messages.<index> past the end would make
        # MongoDB pad the array with nulls.
        if not 0 <= message_index < len(messages):
            st.error("Failed to regenerate response: message no longer exists.")
            return False

        user_message = messages[message_index - 1].get("content", "") if message_index > 0 else ""
        new_response, is_uncertain = get_gpt_response(user_message, system_message=operator_input)
        if new_response is None:
            # Generation failed and get_gpt_response already displayed the
            # error. Keep the existing response instead of overwriting it with None.
            return False

        status = "pending" if is_uncertain else "approved"
        conversation_history.update_one(
            {"session_id": session_id},
            {
                "$set": {
                    f"messages.{message_index}.content": new_response,
                    f"messages.{message_index}.status": status,
                }
            },
        )
        return True
    except Exception as e:
        st.error(f"Failed to regenerate response: {str(e)}")
        return False
# --- Admin Page ---
def admin_page():
    """Render the operator dashboard.

    The page auto-refreshes every 10 seconds. On each run it deletes inactive
    chats via ``cleanup_old_chats`` and shows two tabs:
    - a grid of recent chats with view and delete buttons;
    - the admin intervention queue.
    Streamlit's rerun/stop control exceptions are re-raised; any other error
    is shown in the UI.
    """
    st.title("🛠️ Operator Dashboard")
    # Add auto-refresh every 10 seconds (10000 milliseconds)
    st_autorefresh(interval=10000, limit=None, key="operator_autorefresh")
    if st.button("🔄 Reload Dashboard"):
        st.rerun()
    try:
        deleted_count = cleanup_old_chats()
        if deleted_count is None:
            st.warning("Unable to perform cleanup. Please check the database connection.")
        elif deleted_count > 0:
            st.success(f"Cleaned up {deleted_count} inactive chat(s).")
        else:
            st.info("No inactive chats to clean up.")

        tab1, tab2 = st.tabs([
            "📋 Current Chats",
            "🔧 Admin Intervention",
        ])
        with tab1:
            _render_recent_chats()
        with tab2:
            handle_admin_intervention()
        st.caption(f"Last refreshed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    except (RerunException, StopException):
        raise
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")


def _render_recent_chats(cols_per_row=3):
    """Show recent chats as a grid of expanders with View/Delete buttons."""
    st.header("Current Chats")
    recent_chats = fetch_recent_chats()
    if not recent_chats:
        st.info("No recent chats found.")
        return
    for start in range(0, len(recent_chats), cols_per_row):
        cols = st.columns(cols_per_row)
        for col, chat in zip(cols, recent_chats[start:start + cols_per_row]):
            session_id = chat["session_id"]
            with col:
                with st.expander(f"Session: {session_id[:8]}...", expanded=False):
                    display_chat_preview(chat)
                    view_col, delete_col = st.columns(2)
                    with view_col:
                        if st.button("View Full Chat", key=f"view_{session_id}"):
                            st.session_state["selected_chat"] = session_id
                            st.rerun()
                    with delete_col:
                        if st.button("Delete Chat", key=f"delete_{session_id}"):
                            delete_chat(session_id)
                            st.rerun()
# --- Fetch Recent Chats ---
def fetch_recent_chats(limit=10, preview_messages=3):
    """Return the most recently updated conversations, newest first.

    Args:
        limit: Maximum number of conversations to return (default 10).
        preview_messages: Number of leading messages to fetch per
            conversation via ``$slice`` (default 3).

    Returns:
        list[dict]: Documents containing ``_id``, ``session_id``,
        ``last_updated`` and the first ``preview_messages`` messages.
    """
    projection = {
        "session_id": 1,
        "last_updated": 1,
        "messages": {"$slice": preview_messages},
    }
    cursor = conversation_history.find({}, projection).sort("last_updated", -1).limit(limit)
    return list(cursor)
# --- Display Chat Preview ---
def display_chat_preview(chat, max_messages=3, max_chars=100):
    """Render a compact preview of a conversation document.

    Args:
        chat: Conversation document with ``session_id`` and optionally
            ``last_updated`` and ``messages``.
        max_messages: Number of leading messages to show (default 3).
        max_chars: Characters of each message to show; longer text is
            truncated and marked with "..." (default 100).
    """
    st.subheader(f"Session: {chat['session_id'][:8]}...")
    # A null last_updated falls back to "now", just like a missing field.
    last_updated = chat.get("last_updated") or datetime.utcnow()
    st.caption(f"Last updated: {last_updated.strftime('%Y-%m-%d %H:%M:%S')}")
    for message in chat.get("messages", [])[:max_messages]:
        role = message.get("role", "unknown")
        content = message.get("content")
        content = "" if content is None else str(content)
        # Only add an ellipsis when the text was actually truncated.
        snippet = content[:max_chars] + ("..." if len(content) > max_chars else "")
        with st.chat_message(role):
            st.markdown(f"**{role.capitalize()}**: {snippet}")
    st.divider()
# --- Delete Chat ---
def delete_chat(session_id):
    """Delete the conversation for ``session_id`` and report the result in the UI."""
    try:
        outcome = conversation_history.delete_one({"session_id": session_id})
        if not outcome.deleted_count > 0:
            st.error("Failed to delete chat. Please try again.")
        else:
            st.success(f"Chat {session_id[:8]}... deleted successfully.")
    except Exception as exc:
        st.error(f"Error deleting chat: {str(exc)}")
# --- Cleanup Old Chats ---
def cleanup_old_chats(max_idle_minutes=5):
    """Delete conversations whose ``last_updated`` is older than the idle cutoff.

    Args:
        max_idle_minutes: Minutes without an update after which a
            conversation is deleted (default 5).

    Returns:
        int | None: Number of deleted conversations, or None if the delete
        failed (the error is logged with its traceback).
    """
    try:
        cutoff_time = datetime.utcnow() - timedelta(minutes=max_idle_minutes)
        result = conversation_history.delete_many({"last_updated": {"$lt": cutoff_time}})
        return result.deleted_count
    except Exception:
        # Report through the configured logging instead of a bare print().
        logging.exception("Error during chat cleanup")
        return None
# --- GPT Response Function ---
def get_gpt_response(prompt, context="", system_message=None, model="gpt-4o-mini"):
    """
    Generates a response from the GPT model based on the user prompt and retrieved context.
    Incorporates the global common memory and optional system message.

    Args:
        prompt: The user query.
        context: Extra context placed before the query in the user message.
        system_message: Optional operator instructions appended to the system prompt.
        model: Chat completions model name (default "gpt-4o-mini").

    Returns:
        tuple: ``(response, is_uncertain)``. ``response`` is the stripped text
        (``""`` if the model returned no text). On error it is ``(None, True)``
        and the error is shown in the UI.
    """
    try:
        common_memory = get_global_common_memory()
        system_msg = (
            "You are a helpful assistant. Use the following context and global common memory "
            "to inform your responses, but don't mention them explicitly unless directly relevant to the user's question."
        )
        if system_message:
            system_msg += f"\n\nOperator Instructions:\n{system_message}"
        if common_memory:
            # str() keeps a non-string memory item from breaking join().
            memory_str = "\n".join(str(item) for item in common_memory)
            system_msg += f"\n\nGlobal Common Memory:\n{memory_str}"
        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": f"Context: {context}\n\nUser query: {prompt}"},
        ]
        completion = openai_client.chat.completions.create(
            model=model,
            messages=messages,
        )
        # Log at debug level instead of printing the whole completion to stdout.
        logging.debug("OpenAI completion: %s", completion)
        # message.content is Optional in the OpenAI SDK (e.g. refusals), so
        # normalize it before stripping.
        response = (completion.choices[0].message.content or "").strip()
        # TODO: Implement your logic to determine if the response is uncertain.
        # Placeholder: an empty answer is treated as uncertain so it is not auto-approved.
        is_uncertain = not response
        return response, is_uncertain
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        return None, True  # Indicates uncertainty due to error
# --- View Full Chat Function ---
def view_full_chat(session_id):
    """Display the full chat and provide takeover functionality.

    Shows every message of the conversation and lets the operator toggle
    takeover mode. While takeover is active, the operator can send messages.
    The activate/deactivate/send helpers report their own success or error
    messages.
    """
    # Add a "Go to Dashboard" button at the top
    if st.button("🔙 Go to Dashboard"):
        st.session_state.pop("selected_chat", None)
        st.rerun()

    conversation = conversation_history.find_one({"session_id": session_id})
    if not conversation:
        st.error("Chat not found.")
        return

    st.header(f"Full Chat - Session ID: {conversation['session_id'][:8]}...")
    # A null last_updated falls back to "now", just like a missing field.
    last_updated = conversation.get("last_updated") or datetime.utcnow()
    st.caption(f"Last updated: {last_updated.strftime('%Y-%m-%d %H:%M:%S')}")
    for message in conversation.get("messages", []):
        with st.chat_message(message["role"]):
            st.markdown(f"**{message['role'].capitalize()}**: {message['content']}")

    # Takeover functionality
    takeover_doc = db.takeover_status.find_one({"session_id": session_id})
    takeover_active = takeover_doc.get("active", False) if takeover_doc else False
    if takeover_active:
        if st.button("Deactivate Takeover"):
            deactivate_takeover(session_id)
            st.rerun()
    else:
        if st.button("Activate Takeover"):
            activate_takeover(session_id)
            st.rerun()

    # If takeover is active, allow operator to send messages
    if takeover_active:
        # clear_on_submit empties the input after sending, so the same text is
        # not accidentally sent twice.
        with st.form(key=f"admin_message_form_{session_id}", clear_on_submit=True):
            admin_message = st.text_input("Enter message to send to the user:")
            submit_button = st.form_submit_button("Send Message")
        if submit_button and admin_message:
            # No unconditional success message here: send_admin_message itself
            # reports whether the message was actually stored.
            send_admin_message(session_id, admin_message)
            st.rerun()
# --- Main Function ---
def main():
    """Route to the full-chat view when a chat is selected, otherwise the dashboard.

    Streamlit's rerun/stop control exceptions propagate untouched so reruns
    keep working. Any other exception is shown as an error message.
    """
    try:
        if "selected_chat" not in st.session_state:
            admin_page()
        else:
            view_full_chat(st.session_state["selected_chat"])
    except (RerunException, StopException):
        raise
    except Exception as exc:
        st.error(f"An unexpected error occurred: {str(exc)}")


if __name__ == "__main__":
    main()