Spaces:
Sleeping
Sleeping
Updated BERT models.
Browse files
app.py
CHANGED
@@ -2,9 +2,9 @@ import streamlit as st
|
|
2 |
from transformers import BertForSequenceClassification, BertTokenizer
|
3 |
import torch
|
4 |
import time
|
5 |
-
import
|
6 |
|
7 |
-
#
|
8 |
def local_css():
|
9 |
st.markdown("""
|
10 |
<style>
|
@@ -71,7 +71,6 @@ def local_css():
|
|
71 |
</style>
|
72 |
""", unsafe_allow_html=True)
|
73 |
|
74 |
-
# Load model and tokenizer
|
75 |
@st.cache_resource
|
76 |
def load_model():
|
77 |
model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
|
@@ -85,19 +84,56 @@ def predict(text, model, tokenizer):
|
|
85 |
outputs = model(**inputs)
|
86 |
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
87 |
predicted_class = torch.argmax(predictions, dim=1).item()
|
|
|
88 |
|
89 |
-
return predicted_class,
|
90 |
|
91 |
-
def get_bot_response(predicted_class, confidence):
|
92 |
-
#
|
93 |
responses = {
|
94 |
-
0:
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
def init_session_state():
|
103 |
if 'messages' not in st.session_state:
|
@@ -113,7 +149,7 @@ def display_chat_history():
|
|
113 |
st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
|
114 |
|
115 |
def main():
|
116 |
-
st.set_page_config(page_title="
|
117 |
local_css()
|
118 |
init_session_state()
|
119 |
|
@@ -121,46 +157,45 @@ def main():
|
|
121 |
model, tokenizer = load_model()
|
122 |
|
123 |
# Chat interface
|
124 |
-
st.title("
|
125 |
-
st.markdown("
|
126 |
|
127 |
# Chat history container
|
128 |
chat_container = st.container()
|
129 |
|
130 |
-
# Input container
|
131 |
with st.container():
|
132 |
col1, col2 = st.columns([6, 1])
|
133 |
with col1:
|
134 |
-
user_input = st.text_input("
|
135 |
with col2:
|
136 |
-
send_button = st.button("
|
137 |
|
138 |
if user_input and send_button:
|
139 |
-
# Add user message
|
140 |
st.session_state.messages.append({"role": "user", "content": user_input})
|
141 |
|
142 |
# Show thinking animation
|
143 |
st.session_state.thinking = True
|
144 |
|
145 |
-
# Get
|
146 |
-
predicted_class,
|
147 |
|
148 |
-
#
|
149 |
-
bot_response = get_bot_response(predicted_class,
|
150 |
|
151 |
-
# Add bot response
|
152 |
-
time.sleep(
|
153 |
st.session_state.messages.append({"role": "assistant", "content": bot_response})
|
154 |
st.session_state.thinking = False
|
155 |
|
156 |
-
# Clear input
|
157 |
st.rerun()
|
158 |
|
159 |
# Display chat history
|
160 |
with chat_container:
|
161 |
display_chat_history()
|
162 |
|
163 |
-
# Show thinking animation
|
164 |
if st.session_state.thinking:
|
165 |
st.markdown("""
|
166 |
<div class="thinking-animation">
|
|
|
2 |
from transformers import BertForSequenceClassification, BertTokenizer
|
3 |
import torch
|
4 |
import time
|
5 |
+
import random
|
6 |
|
7 |
+
# [Previous CSS styles remain the same]
|
8 |
def local_css():
|
9 |
st.markdown("""
|
10 |
<style>
|
|
|
71 |
</style>
|
72 |
""", unsafe_allow_html=True)
|
73 |
|
|
|
74 |
@st.cache_resource
|
75 |
def load_model():
|
76 |
model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
|
|
|
84 |
outputs = model(**inputs)
|
85 |
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
86 |
predicted_class = torch.argmax(predictions, dim=1).item()
|
87 |
+
confidence = predictions[0][predicted_class].item()
|
88 |
|
89 |
+
return predicted_class, confidence
|
90 |
|
91 |
+
def get_bot_response(text, predicted_class, confidence):
|
92 |
+
# Define response templates based on classes and confidence levels
|
93 |
responses = {
|
94 |
+
0: { # Example for class 0 (positive sentiment)
|
95 |
+
'high_conf': [
|
96 |
+
"Tôi cảm nhận được sự tích cực trong câu nói của bạn. Xin chia sẻ thêm nhé!",
|
97 |
+
"Thật vui khi nghe điều đó. Bạn có thể kể thêm không?",
|
98 |
+
"Tuyệt vời! Tôi rất đồng ý với bạn về điều này."
|
99 |
+
],
|
100 |
+
'low_conf': [
|
101 |
+
"Có vẻ như đây là điều tích cực. Đúng không nhỉ?",
|
102 |
+
"Tôi nghĩ đây là một góc nhìn thú vị đấy.",
|
103 |
+
"Nghe có vẻ tốt đấy, bạn nghĩ sao?"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
1: { # Example for class 1 (negative sentiment)
|
107 |
+
'high_conf': [
|
108 |
+
"Tôi hiểu đây là điều khó khăn với bạn. Hãy chia sẻ thêm nhé.",
|
109 |
+
"Tôi rất tiếc khi nghe điều này. Bạn cần tôi giúp gì không?",
|
110 |
+
"Đúng là một tình huống khó khăn. Chúng ta cùng tìm giải pháp nhé."
|
111 |
+
],
|
112 |
+
'low_conf': [
|
113 |
+
"Có vẻ như bạn đang gặp khó khăn. Tôi có hiểu đúng không?",
|
114 |
+
"Tôi không chắc mình hiểu hết, bạn có thể giải thích thêm được không?",
|
115 |
+
"Hãy chia sẻ thêm để tôi có thể hiểu rõ hơn nhé."
|
116 |
+
]
|
117 |
+
}
|
118 |
}
|
119 |
+
|
120 |
+
# Add more classes based on your model's output
|
121 |
+
|
122 |
+
# Determine confidence level
|
123 |
+
confidence_threshold = 0.8
|
124 |
+
conf_level = 'high_conf' if confidence > confidence_threshold else 'low_conf'
|
125 |
|
126 |
+
# Get appropriate response list
|
127 |
+
try:
|
128 |
+
response_list = responses[predicted_class][conf_level]
|
129 |
+
response = random.choice(response_list)
|
130 |
+
except KeyError:
|
131 |
+
response = "Xin lỗi, tôi không chắc chắn về điều này. Bạn có thể giải thích rõ hơn được không?"
|
132 |
+
|
133 |
+
# Add context from user's input
|
134 |
+
context_response = f"{response}"
|
135 |
+
|
136 |
+
return context_response
|
137 |
|
138 |
def init_session_state():
|
139 |
if 'messages' not in st.session_state:
|
|
|
149 |
st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
|
150 |
|
151 |
def main():
|
152 |
+
st.set_page_config(page_title="Vietnamese Chatbot", page_icon="🤖", layout="wide")
|
153 |
local_css()
|
154 |
init_session_state()
|
155 |
|
|
|
157 |
model, tokenizer = load_model()
|
158 |
|
159 |
# Chat interface
|
160 |
+
st.title("Chatbot Tiếng Việt 🤖")
|
161 |
+
st.markdown("Xin chào! Tôi có thể giúp gì cho bạn?")
|
162 |
|
163 |
# Chat history container
|
164 |
chat_container = st.container()
|
165 |
|
166 |
+
# Input container
|
167 |
with st.container():
|
168 |
col1, col2 = st.columns([6, 1])
|
169 |
with col1:
|
170 |
+
user_input = st.text_input("Nhập tin nhắn của bạn...", key="user_input", label_visibility="hidden")
|
171 |
with col2:
|
172 |
+
send_button = st.button("Gửi")
|
173 |
|
174 |
if user_input and send_button:
|
175 |
+
# Add user message
|
176 |
st.session_state.messages.append({"role": "user", "content": user_input})
|
177 |
|
178 |
# Show thinking animation
|
179 |
st.session_state.thinking = True
|
180 |
|
181 |
+
# Get prediction
|
182 |
+
predicted_class, confidence = predict(user_input, model, tokenizer)
|
183 |
|
184 |
+
# Generate response
|
185 |
+
bot_response = get_bot_response(user_input, predicted_class, confidence)
|
186 |
|
187 |
+
# Add bot response
|
188 |
+
time.sleep(0.5) # Brief delay for natural feeling
|
189 |
st.session_state.messages.append({"role": "assistant", "content": bot_response})
|
190 |
st.session_state.thinking = False
|
191 |
|
192 |
+
# Clear input and rerun
|
193 |
st.rerun()
|
194 |
|
195 |
# Display chat history
|
196 |
with chat_container:
|
197 |
display_chat_history()
|
198 |
|
|
|
199 |
if st.session_state.thinking:
|
200 |
st.markdown("""
|
201 |
<div class="thinking-animation">
|