JustKiddo commited on
Commit
2bc9c40
·
verified ·
1 Parent(s): a4c3bcc

Updated BERT models.

Browse files
Files changed (1) hide show
  1. app.py +61 -26
app.py CHANGED
@@ -2,9 +2,9 @@ import streamlit as st
2
  from transformers import BertForSequenceClassification, BertTokenizer
3
  import torch
4
  import time
5
- import streamlit.components.v1 as components
6
 
7
- # Custom CSS for chat interface
8
  def local_css():
9
  st.markdown("""
10
  <style>
@@ -71,7 +71,6 @@ def local_css():
71
  </style>
72
  """, unsafe_allow_html=True)
73
 
74
- # Load model and tokenizer
75
  @st.cache_resource
76
  def load_model():
77
  model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
@@ -85,19 +84,56 @@ def predict(text, model, tokenizer):
85
  outputs = model(**inputs)
86
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
87
  predicted_class = torch.argmax(predictions, dim=1).item()
 
88
 
89
- return predicted_class, predictions[0]
90
 
91
- def get_bot_response(predicted_class, confidence):
92
- # Customize these responses based on your model's classes
93
  responses = {
94
- 0: "I understand this is about [Class 0]. Let me help you with that.",
95
- 1: "This seems to be related to [Class 1]. Here's what I can tell you.",
96
- # Add more responses for your classes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  }
 
 
 
 
 
 
98
 
99
- default_response = "I'm not quite sure about that. Could you please rephrase?"
100
- return responses.get(predicted_class, default_response)
 
 
 
 
 
 
 
 
 
101
 
102
  def init_session_state():
103
  if 'messages' not in st.session_state:
@@ -113,7 +149,7 @@ def display_chat_history():
113
  st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
114
 
115
  def main():
116
- st.set_page_config(page_title="AI Chatbot", page_icon="🤖", layout="wide")
117
  local_css()
118
  init_session_state()
119
 
@@ -121,46 +157,45 @@ def main():
121
  model, tokenizer = load_model()
122
 
123
  # Chat interface
124
- st.title("AI Chatbot 🤖")
125
- st.markdown("Welcome! I'm here to help answer your questions in Vietnamese.")
126
 
127
  # Chat history container
128
  chat_container = st.container()
129
 
130
- # Input container at the bottom
131
  with st.container():
132
  col1, col2 = st.columns([6, 1])
133
  with col1:
134
- user_input = st.text_input("Type your message...", key="user_input", label_visibility="hidden")
135
  with col2:
136
- send_button = st.button("Send")
137
 
138
  if user_input and send_button:
139
- # Add user message to chat
140
  st.session_state.messages.append({"role": "user", "content": user_input})
141
 
142
  # Show thinking animation
143
  st.session_state.thinking = True
144
 
145
- # Get model prediction
146
- predicted_class, probabilities = predict(user_input, model, tokenizer)
147
 
148
- # Get bot response
149
- bot_response = get_bot_response(predicted_class, probabilities)
150
 
151
- # Add bot response to chat
152
- time.sleep(1) # Simulate processing time
153
  st.session_state.messages.append({"role": "assistant", "content": bot_response})
154
  st.session_state.thinking = False
155
 
156
- # Clear input
157
  st.rerun()
158
 
159
  # Display chat history
160
  with chat_container:
161
  display_chat_history()
162
 
163
- # Show thinking animation
164
  if st.session_state.thinking:
165
  st.markdown("""
166
  <div class="thinking-animation">
 
2
  from transformers import BertForSequenceClassification, BertTokenizer
3
  import torch
4
  import time
5
+ import random
6
 
7
+ # [Previous CSS styles remain the same]
8
  def local_css():
9
  st.markdown("""
10
  <style>
 
71
  </style>
72
  """, unsafe_allow_html=True)
73
 
 
74
  @st.cache_resource
75
  def load_model():
76
  model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
 
84
  outputs = model(**inputs)
85
  predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
86
  predicted_class = torch.argmax(predictions, dim=1).item()
87
+ confidence = predictions[0][predicted_class].item()
88
 
89
+ return predicted_class, confidence
90
 
91
+ def get_bot_response(text, predicted_class, confidence):
92
+ # Define response templates based on classes and confidence levels
93
  responses = {
94
+ 0: { # Example for class 0 (positive sentiment)
95
+ 'high_conf': [
96
+ "Tôi cảm nhận được sự tích cực trong câu nói của bạn. Xin chia sẻ thêm nhé!",
97
+ "Thật vui khi nghe điều đó. Bạn có thể kể thêm không?",
98
+ "Tuyệt vời! Tôi rất đồng ý với bạn về điều này."
99
+ ],
100
+ 'low_conf': [
101
+ "Có vẻ như đây là điều tích cực. Đúng không nhỉ?",
102
+ "Tôi nghĩ đây là một góc nhìn thú vị đấy.",
103
+ "Nghe có vẻ tốt đấy, bạn nghĩ sao?"
104
+ ]
105
+ },
106
+ 1: { # Example for class 1 (negative sentiment)
107
+ 'high_conf': [
108
+ "Tôi hiểu đây là điều khó khăn với bạn. Hãy chia sẻ thêm nhé.",
109
+ "Tôi rất tiếc khi nghe điều này. Bạn cần tôi giúp gì không?",
110
+ "Đúng là một tình huống khó khăn. Chúng ta cùng tìm giải pháp nhé."
111
+ ],
112
+ 'low_conf': [
113
+ "Có vẻ như bạn đang gặp khó khăn. Tôi có hiểu đúng không?",
114
+ "Tôi không chắc mình hiểu hết, bạn có thể giải thích thêm được không?",
115
+ "Hãy chia sẻ thêm để tôi có thể hiểu rõ hơn nhé."
116
+ ]
117
+ }
118
  }
119
+
120
+ # Add more classes based on your model's output
121
+
122
+ # Determine confidence level
123
+ confidence_threshold = 0.8
124
+ conf_level = 'high_conf' if confidence > confidence_threshold else 'low_conf'
125
 
126
+ # Get appropriate response list
127
+ try:
128
+ response_list = responses[predicted_class][conf_level]
129
+ response = random.choice(response_list)
130
+ except KeyError:
131
+ response = "Xin lỗi, tôi không chắc chắn về điều này. Bạn có thể giải thích rõ hơn được không?"
132
+
133
+ # Add context from user's input
134
+ context_response = f"{response}"
135
+
136
+ return context_response
137
 
138
  def init_session_state():
139
  if 'messages' not in st.session_state:
 
149
  st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
150
 
151
  def main():
152
+ st.set_page_config(page_title="Vietnamese Chatbot", page_icon="🤖", layout="wide")
153
  local_css()
154
  init_session_state()
155
 
 
157
  model, tokenizer = load_model()
158
 
159
  # Chat interface
160
+ st.title("Chatbot Tiếng Việt 🤖")
161
+ st.markdown("Xin chào! Tôi thể giúp cho bạn?")
162
 
163
  # Chat history container
164
  chat_container = st.container()
165
 
166
+ # Input container
167
  with st.container():
168
  col1, col2 = st.columns([6, 1])
169
  with col1:
170
+ user_input = st.text_input("Nhập tin nhắn của bạn...", key="user_input", label_visibility="hidden")
171
  with col2:
172
+ send_button = st.button("Gửi")
173
 
174
  if user_input and send_button:
175
+ # Add user message
176
  st.session_state.messages.append({"role": "user", "content": user_input})
177
 
178
  # Show thinking animation
179
  st.session_state.thinking = True
180
 
181
+ # Get prediction
182
+ predicted_class, confidence = predict(user_input, model, tokenizer)
183
 
184
+ # Generate response
185
+ bot_response = get_bot_response(user_input, predicted_class, confidence)
186
 
187
+ # Add bot response
188
+ time.sleep(0.5) # Brief delay for natural feeling
189
  st.session_state.messages.append({"role": "assistant", "content": bot_response})
190
  st.session_state.thinking = False
191
 
192
+ # Clear input and rerun
193
  st.rerun()
194
 
195
  # Display chat history
196
  with chat_container:
197
  display_chat_history()
198
 
 
199
  if st.session_state.thinking:
200
  st.markdown("""
201
  <div class="thinking-animation">