JustKiddo committed
Commit 9d2a89b · verified · 1 Parent(s): d95cfa5

Create app.bak

Files changed (1)
  1. app.bak +313 -0
app.bak ADDED
@@ -0,0 +1,313 @@
+ import streamlit as st
+ import requests
+ from bertopic import BERTopic
+ from sentence_transformers import SentenceTransformer
+ import numpy as np
+ from sklearn.metrics.pairwise import cosine_similarity
+ import pandas as pd
+ import plotly.graph_objects as go
+ from datetime import datetime
+ import json
+ from collections import deque
+ from datasets import load_dataset
+
+ class BERTopicChatbot:
+     """Chatbot initialized from a Hugging Face dataset.
+
+     dataset_name: name of the dataset on Hugging Face (e.g., 'vietnam/legal')
+     text_column: name of the column containing the text data
+     split: which split of the dataset to use ('train', 'test', 'validation')
+     max_samples: maximum number of samples to use (to manage memory)
+     """
+
+     def __init__(self, dataset_name, text_column, split="train", max_samples=10000):
+         # Initialize BERT sentence transformer
+         self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
+
+         # Load dataset from Hugging Face
+         try:
+             dataset = load_dataset(dataset_name, split=split)
+
+             # Sample if necessary, then convert to a pandas DataFrame
+             if len(dataset) > max_samples:
+                 dataset = dataset.shuffle(seed=42).select(range(max_samples))
+
+             self.df = dataset.to_pandas()
+
+             # Ensure the text column exists
+             if text_column not in self.df.columns:
+                 raise ValueError(f"Column '{text_column}' not found in dataset. Available columns: {self.df.columns}")
+
+             self.documents = self.df[text_column].tolist()
+
+             # Create and train the BERTopic model
+             self.topic_model = BERTopic(embedding_model=self.sentence_model)
+             self.topics, self.probs = self.topic_model.fit_transform(self.documents)
+
+             # Create document embeddings for similarity search
+             self.doc_embeddings = self.sentence_model.encode(self.documents)
+
+             # Initialize metrics storage
+             self.metrics_history = {
+                 'similarities': deque(maxlen=100),
+                 'response_times': deque(maxlen=100),
+                 'token_counts': deque(maxlen=100),
+                 'topics_accessed': {}
+             }
+
+             # Store dataset info
+             self.dataset_info = {
+                 'name': dataset_name,
+                 'split': split,
+                 'total_documents': len(self.documents),
+                 'topics_found': len(set(self.topics))
+             }
+
+         except Exception as e:
+             st.error(f"Error loading dataset: {str(e)}")
+             raise
+
+     def get_metrics_visualizations(self):
+         """Generate visualizations for chatbot metrics"""
+         # Similarity trend
+         fig_similarity = go.Figure()
+         fig_similarity.add_trace(go.Scatter(
+             y=list(self.metrics_history['similarities']),
+             mode='lines+markers',
+             name='Similarity Score'
+         ))
+         fig_similarity.update_layout(
+             title='Response Similarity Trend',
+             yaxis_title='Similarity Score',
+             xaxis_title='Query Number'
+         )
+
+         # Response time trend
+         fig_response_time = go.Figure()
+         fig_response_time.add_trace(go.Scatter(
+             y=list(self.metrics_history['response_times']),
+             mode='lines+markers',
+             name='Response Time'
+         ))
+         fig_response_time.update_layout(
+             title='Response Time Trend',
+             yaxis_title='Time (seconds)',
+             xaxis_title='Query Number'
+         )
+
+         # Token usage trend
+         fig_tokens = go.Figure()
+         fig_tokens.add_trace(go.Scatter(
+             y=list(self.metrics_history['token_counts']),
+             mode='lines+markers',
+             name='Token Count'
+         ))
+         fig_tokens.update_layout(
+             title='Token Usage Trend',
+             yaxis_title='Number of Tokens',
+             xaxis_title='Query Number'
+         )
+
+         # Topics accessed pie chart
+         labels = list(self.metrics_history['topics_accessed'].keys())
+         values = list(self.metrics_history['topics_accessed'].values())
+         fig_topics = go.Figure(data=[go.Pie(labels=labels, values=values)])
+         fig_topics.update_layout(title='Topics Accessed Distribution')
+
+         # Make all figures responsive
+         for fig in [fig_similarity, fig_response_time, fig_tokens, fig_topics]:
+             fig.update_layout(
+                 autosize=True,
+                 margin=dict(l=20, r=20, t=40, b=20),
+                 height=300
+             )
+
+         return fig_similarity, fig_response_time, fig_tokens, fig_topics
+
+     def get_most_similar_document(self, query, top_k=3):
+         # Encode the query
+         query_embedding = self.sentence_model.encode([query])[0]
+
+         # Calculate similarities
+         similarities = cosine_similarity([query_embedding], self.doc_embeddings)[0]
+
+         # Get top k most similar documents
+         top_indices = similarities.argsort()[-top_k:][::-1]
+
+         return [self.documents[i] for i in top_indices], similarities[top_indices]
+
+     def get_response(self, user_query):
+         try:
+             start_time = datetime.now()
+
+             # Get most similar documents
+             similar_docs, similarities = self.get_most_similar_document(user_query)
+
+             # Get topic for the query
+             query_topic, _ = self.topic_model.transform([user_query])
+
+             # Track topic access
+             topic_id = str(query_topic[0])
+             self.metrics_history['topics_accessed'][topic_id] = \
+                 self.metrics_history['topics_accessed'].get(topic_id, 0) + 1
+
+             # If similarity is too low, return a default response
+             if max(similarities) < 0.5:
+                 # "Sorry, I do not have enough information to answer this question accurately."
+                 response = "Xin lỗi, tôi không có đủ thông tin để trả lời câu hỏi này một cách chính xác."
+             else:
+                 response = similar_docs[0]
+
+             # Track metrics
+             end_time = datetime.now()
+             self.metrics_history['similarities'].append(float(max(similarities)))
+             self.metrics_history['response_times'].append((end_time - start_time).total_seconds())
+             self.metrics_history['token_counts'].append(len(response.split()))
+
+             metrics = {
+                 'similarity': float(max(similarities)),
+                 'response_time': (end_time - start_time).total_seconds(),
+                 'tokens': len(response.split()),
+                 'topic': topic_id
+             }
+
+             return response, metrics
+
+         except Exception as e:
+             return f"Error processing query: {str(e)}", {'error': str(e)}
+
+     def get_dataset_info(self):
+         """Return information about the loaded dataset and metrics"""
+         try:
+             return {
+                 'dataset_info': self.dataset_info,
+                 'metrics': {
+                     'avg_similarity': np.mean(list(self.metrics_history['similarities'])) if self.metrics_history['similarities'] else 0,
+                     'avg_response_time': np.mean(list(self.metrics_history['response_times'])) if self.metrics_history['response_times'] else 0,
+                     'total_tokens': sum(self.metrics_history['token_counts']),
+                     'topics_accessed': self.metrics_history['topics_accessed']
+                 }
+             }
+         except Exception as e:
+             return {
+                 'error': str(e),
+                 'dataset_info': None,
+                 'metrics': None
+             }
+
+ # Cache the chatbot so the dataset and topic model are only built once per configuration
+ @st.cache_resource
+ def initialize_chatbot(dataset_name, text_column, split="train", max_samples=10000):
+     return BERTopicChatbot(dataset_name, text_column, split, max_samples)
+
+ def main():
+     st.title("🤖 Trợ Lý AI - BERTopic")  # "AI Assistant - BERTopic"
+     st.caption("Trò chuyện với chúng mình nhé!")  # "Chat with us!"
+
+     # Dataset selection sidebar
+     with st.sidebar:
+         st.header("Dataset Configuration")
+         dataset_name = st.text_input(
+             "Hugging Face Dataset Name",
+             value="Kanakmi/mental-disorders",
+             help="Enter the name of a dataset from Hugging Face (e.g., 'Kanakmi/mental-disorders')"
+         )
+         text_column = st.text_input(
+             "Text Column Name",
+             value="text",
+             help="Enter the name of the column containing the text data"
+         )
+         split = st.selectbox(
+             "Dataset Split",
+             options=["train", "test", "validation"],
+             index=0
+         )
+         max_samples = st.number_input(
+             "Maximum Samples",
+             min_value=100,
+             max_value=100000,
+             value=10000,
+             step=1000,
+             help="Maximum number of samples to load from the dataset"
+         )
+
+         if st.button("Load Dataset"):
+             with st.spinner("Loading dataset and initializing model..."):
+                 try:
+                     st.session_state.chatbot = initialize_chatbot(
+                         dataset_name, text_column, split, max_samples
+                     )
+                     st.success("Dataset loaded successfully!")
+                 except Exception as e:
+                     st.error(f"Error loading dataset: {str(e)}")
+
+     # Initialize session state variables if they don't exist
+     if 'chatbot' not in st.session_state:
+         st.session_state.chatbot = None
+
+     if 'messages' not in st.session_state:
+         st.session_state.messages = []
+
+     # Create tabs for chat and metrics
+     chat_tab, metrics_tab = st.tabs(["Chat", "Metrics"])
+
+     with chat_tab:
+         # Display existing messages
+         for message in st.session_state.messages:
+             with st.chat_message(message["role"]):
+                 st.markdown(message["content"])
+
+         # Only show chat input if chatbot is initialized
+         if st.session_state.chatbot is not None:
+             if prompt := st.chat_input("Hãy nói gì đó..."):  # "Say something..."
+                 # Add user message
+                 st.session_state.messages.append({"role": "user", "content": prompt})
+                 with st.chat_message("user"):
+                     st.markdown(prompt)
+
+                 # Get chatbot response
+                 response, metrics = st.session_state.chatbot.get_response(prompt)
+
+                 # Add assistant response
+                 with st.chat_message("assistant"):
+                     st.markdown(response)
+                     with st.expander("Response Metrics"):
+                         st.json(metrics)
+
+                 st.session_state.messages.append({"role": "assistant", "content": response})
+         else:
+             st.info("Please load a dataset first to start chatting.")
+
+     with metrics_tab:
+         if st.session_state.chatbot is not None:
+             try:
+                 # Get visualizations from the session-state chatbot
+                 fig_similarity, fig_response_time, fig_tokens, fig_topics = st.session_state.chatbot.get_metrics_visualizations()
+
+                 col1, col2 = st.columns(2)
+                 with col1:
+                     st.plotly_chart(fig_similarity, use_container_width=True)
+                     st.plotly_chart(fig_tokens, use_container_width=True)
+
+                 with col2:
+                     st.plotly_chart(fig_response_time, use_container_width=True)
+                     st.plotly_chart(fig_topics, use_container_width=True)
+
+                 # Display statistics
+                 st.subheader("Overall Statistics")
+                 metrics_history = st.session_state.chatbot.metrics_history
+                 if len(metrics_history['similarities']) > 0:
+                     stats_col1, stats_col2, stats_col3 = st.columns(3)
+                     with stats_col1:
+                         st.metric("Avg Similarity",
+                                   f"{np.mean(list(metrics_history['similarities'])):.3f}")
+                     with stats_col2:
+                         st.metric("Avg Response Time",
+                                   f"{np.mean(list(metrics_history['response_times'])):.3f}s")
+                     with stats_col3:
+                         st.metric("Total Tokens Used",
+                                   sum(metrics_history['token_counts']))
+                 else:
+                     st.info("No chat history available yet. Start a conversation to see metrics.")
+             except Exception as e:
+                 st.error(f"Error displaying metrics: {str(e)}")
+         else:
+             st.info("Please load a dataset first to view metrics.")
+
+ if __name__ == "__main__":
+     main()
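
For reference, a minimal sketch of exercising the BERTopicChatbot class above outside the Streamlit UI. Assumptions not in the commit: the file is saved locally as app.py, the imported libraries are installed, and the query string is purely illustrative; the dataset name and column mirror the sidebar defaults in main().

# Hypothetical standalone usage; not part of the committed file.
from app import BERTopicChatbot  # assumes app.bak was saved as app.py

bot = BERTopicChatbot(
    dataset_name="Kanakmi/mental-disorders",  # sidebar default dataset
    text_column="text",
    split="train",
    max_samples=1000,
)

response, metrics = bot.get_response("What are common symptoms of anxiety?")
print(response)
print(metrics)                 # similarity, response_time, tokens, topic
print(bot.get_dataset_info())  # dataset summary plus aggregate metrics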