"""Streamlit chatbot that answers questions by retrieving the most similar
documents from a Hugging Face dataset, using BERTopic for topic modelling and
sentence-transformer embeddings for similarity search, with built-in metrics
dashboards (similarity, response time, token usage, topics accessed)."""

import streamlit as st
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import plotly.graph_objects as go
from datetime import datetime
from collections import deque
from datasets import load_dataset


class BERTopicChatbot:
    """Retrieval chatbot built from a Hugging Face dataset.

    Args:
        dataset_name: name of the dataset on Hugging Face (e.g., 'vietnam/legal')
        text_column: name of the column containing the text data
        split: which split of the dataset to use ('train', 'test', 'validation')
        max_samples: maximum number of samples to use (to manage memory)
    """

    def __init__(self, dataset_name, text_column, split="train", max_samples=10000):
        # Initialize BERT sentence transformer
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Load dataset from Hugging Face
        try:
            dataset = load_dataset(dataset_name, split=split)

            # Subsample if necessary, then convert to a pandas DataFrame
            if len(dataset) > max_samples:
                dataset = dataset.shuffle(seed=42).select(range(max_samples))
            self.df = dataset.to_pandas()

            # Ensure the text column exists
            if text_column not in self.df.columns:
                raise ValueError(
                    f"Column '{text_column}' not found in dataset. "
                    f"Available columns: {list(self.df.columns)}"
                )

            self.documents = self.df[text_column].tolist()

            # Create and train BERTopic model
            self.topic_model = BERTopic(embedding_model=self.sentence_model)
            self.topics, self.probs = self.topic_model.fit_transform(self.documents)

            # Create document embeddings for similarity search
            self.doc_embeddings = self.sentence_model.encode(self.documents)

            # Initialize metrics storage
            self.metrics_history = {
                'similarities': deque(maxlen=100),
                'response_times': deque(maxlen=100),
                'token_counts': deque(maxlen=100),
                'topics_accessed': {}
            }

            # Store dataset info
            self.dataset_info = {
                'name': dataset_name,
                'split': split,
                'total_documents': len(self.documents),
                'topics_found': len(set(self.topics))
            }
        except Exception as e:
            st.error(f"Error loading dataset: {str(e)}")
            raise

    def get_metrics_visualizations(self):
        """Generate visualizations for chatbot metrics."""
        # Similarity trend
        fig_similarity = go.Figure()
        fig_similarity.add_trace(go.Scatter(
            y=list(self.metrics_history['similarities']),
            mode='lines+markers',
            name='Similarity Score'
        ))
        fig_similarity.update_layout(
            title='Response Similarity Trend',
            yaxis_title='Similarity Score',
            xaxis_title='Query Number'
        )

        # Response time trend
        fig_response_time = go.Figure()
        fig_response_time.add_trace(go.Scatter(
            y=list(self.metrics_history['response_times']),
            mode='lines+markers',
            name='Response Time'
        ))
        fig_response_time.update_layout(
            title='Response Time Trend',
            yaxis_title='Time (seconds)',
            xaxis_title='Query Number'
        )

        # Token usage trend
        fig_tokens = go.Figure()
        fig_tokens.add_trace(go.Scatter(
            y=list(self.metrics_history['token_counts']),
            mode='lines+markers',
            name='Token Count'
        ))
        fig_tokens.update_layout(
            title='Token Usage Trend',
            yaxis_title='Number of Tokens',
            xaxis_title='Query Number'
        )

        # Topics accessed pie chart
        labels = list(self.metrics_history['topics_accessed'].keys())
        values = list(self.metrics_history['topics_accessed'].values())
        fig_topics = go.Figure(data=[go.Pie(labels=labels, values=values)])
        fig_topics.update_layout(title='Topics Accessed Distribution')

        # Make all figures responsive
        for fig in [fig_similarity, fig_response_time, fig_tokens, fig_topics]:
            fig.update_layout(
                autosize=True,
                margin=dict(l=20, r=20, t=40, b=20),
                height=300
            )

        return fig_similarity, fig_response_time, fig_tokens, fig_topics
    def get_most_similar_document(self, query, top_k=3):
        """Return the top_k most similar documents to the query and their scores."""
        # Encode the query
        query_embedding = self.sentence_model.encode([query])[0]

        # Calculate cosine similarities against all document embeddings
        similarities = cosine_similarity([query_embedding], self.doc_embeddings)[0]

        # Get top k most similar documents
        top_indices = similarities.argsort()[-top_k:][::-1]
        return [self.documents[i] for i in top_indices], similarities[top_indices]

    def get_response(self, user_query):
        try:
            start_time = datetime.now()

            # Get most similar documents
            similar_docs, similarities = self.get_most_similar_document(user_query)

            # Get topic for the query
            query_topic, _ = self.topic_model.transform([user_query])

            # Track topic access
            topic_id = str(query_topic[0])
            self.metrics_history['topics_accessed'][topic_id] = \
                self.metrics_history['topics_accessed'].get(topic_id, 0) + 1

            # If similarity is too low, return a default response
            # ("Sorry, I don't have enough information to answer this question accurately.")
            if max(similarities) < 0.5:
                response = "Xin lỗi, tôi không có đủ thông tin để trả lời câu hỏi này một cách chính xác."
            else:
                response = similar_docs[0]

            # Track metrics
            end_time = datetime.now()
            self.metrics_history['similarities'].append(float(max(similarities)))
            self.metrics_history['response_times'].append((end_time - start_time).total_seconds())
            self.metrics_history['token_counts'].append(len(response.split()))

            metrics = {
                'similarity': float(max(similarities)),
                'response_time': (end_time - start_time).total_seconds(),
                'tokens': len(response.split()),
                'topic': topic_id
            }

            return response, metrics
        except Exception as e:
            return f"Error processing query: {str(e)}", {'error': str(e)}

    def get_dataset_info(self):
        """Return information about the loaded dataset and accumulated metrics."""
        try:
            return {
                'dataset_info': self.dataset_info,
                'metrics': {
                    'avg_similarity': np.mean(list(self.metrics_history['similarities'])) if self.metrics_history['similarities'] else 0,
                    'avg_response_time': np.mean(list(self.metrics_history['response_times'])) if self.metrics_history['response_times'] else 0,
                    'total_tokens': sum(self.metrics_history['token_counts']),
                    'topics_accessed': self.metrics_history['topics_accessed']
                }
            }
        except Exception as e:
            return {
                'error': str(e),
                'dataset_info': None,
                'metrics': None
            }


@st.cache_resource
def initialize_chatbot(dataset_name, text_column, split="train", max_samples=10000):
    return BERTopicChatbot(dataset_name, text_column, split, max_samples)


def main():
    st.title("🤖 Trợ Lý AI - BERTopic")  # "AI Assistant - BERTopic"
    st.caption("Trò chuyện với chúng mình nhé!")  # "Chat with us!"

    # Dataset selection sidebar
    with st.sidebar:
        st.header("Dataset Configuration")
        dataset_name = st.text_input(
            "Hugging Face Dataset Name",
            value="Kanakmi/mental-disorders",
            help="Enter the name of a dataset from Hugging Face (e.g., 'Kanakmi/mental-disorders')"
        )
        text_column = st.text_input(
            "Text Column Name",
            value="text",
            help="Enter the name of the column containing the text data"
        )
        split = st.selectbox(
            "Dataset Split",
            options=["train", "test", "validation"],
            index=0
        )
        max_samples = st.number_input(
            "Maximum Samples",
            min_value=100,
            max_value=100000,
            value=10000,
            step=1000,
            help="Maximum number of samples to load from the dataset"
        )

        if st.button("Load Dataset"):
            with st.spinner("Loading dataset and initializing model..."):
                try:
                    st.session_state.chatbot = initialize_chatbot(
                        dataset_name, text_column, split, max_samples
                    )
                    st.success("Dataset loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading dataset: {str(e)}")

    # Initialize session state variables if they don't exist
    if 'chatbot' not in st.session_state:
        st.session_state.chatbot = None
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Create tabs for chat and metrics
    chat_tab, metrics_tab = st.tabs(["Chat", "Metrics"])

    with chat_tab:
        # Display existing messages
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Only show chat input if chatbot is initialized
        if st.session_state.chatbot is not None:
            if prompt := st.chat_input("Hãy nói gì đó..."):  # "Say something..."
                # Add user message
                st.session_state.messages.append({"role": "user", "content": prompt})
                with st.chat_message("user"):
                    st.markdown(prompt)

                # Get chatbot response
                response, metrics = st.session_state.chatbot.get_response(prompt)

                # Add assistant response
                with st.chat_message("assistant"):
                    st.markdown(response)
                    with st.expander("Response Metrics"):
                        st.json(metrics)
                st.session_state.messages.append({"role": "assistant", "content": response})
        else:
            st.info("Please load a dataset first to start chatting.")

    with metrics_tab:
        if st.session_state.chatbot is not None:
            try:
                # Get visualizations from session state chatbot
                fig_similarity, fig_response_time, fig_tokens, fig_topics = \
                    st.session_state.chatbot.get_metrics_visualizations()

                col1, col2 = st.columns(2)
                with col1:
                    st.plotly_chart(fig_similarity, use_container_width=True)
                    st.plotly_chart(fig_tokens, use_container_width=True)
                with col2:
                    st.plotly_chart(fig_response_time, use_container_width=True)
                    st.plotly_chart(fig_topics, use_container_width=True)

                # Display statistics
                st.subheader("Overall Statistics")
                metrics_history = st.session_state.chatbot.metrics_history
                if len(metrics_history['similarities']) > 0:
                    stats_col1, stats_col2, stats_col3 = st.columns(3)
                    with stats_col1:
                        st.metric("Avg Similarity", f"{np.mean(list(metrics_history['similarities'])):.3f}")
                    with stats_col2:
                        st.metric("Avg Response Time", f"{np.mean(list(metrics_history['response_times'])):.3f}s")
                    with stats_col3:
                        st.metric("Total Tokens Used", sum(metrics_history['token_counts']))
                else:
                    st.info("No chat history available yet. Start a conversation to see metrics.")
            except Exception as e:
                st.error(f"Error displaying metrics: {str(e)}")
        else:
            st.info("Please load a dataset first to view metrics.")


if __name__ == "__main__":
    main()
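
# --- Usage sketch (not part of the app itself) -------------------------------
# A minimal way to try this script, assuming it is saved as app.py; the filename
# and the package list below are assumptions, not stated anywhere in the code:
#
#   pip install streamlit bertopic sentence-transformers datasets scikit-learn plotly
#   streamlit run app.py
#
# The BERTopicChatbot class can also be exercised without the Streamlit UI, for
# example in a notebook. The dataset name and column match the sidebar defaults
# above; the query string is illustrative only:
#
#   bot = BERTopicChatbot("Kanakmi/mental-disorders", "text", split="train", max_samples=1000)
#   answer, metrics = bot.get_response("How is anxiety different from stress?")
#   print(answer, metrics)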