JustKiddo commited on
Commit
cb4d8f6
·
verified ·
1 Parent(s): 2bc9c40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -74
app.py CHANGED
@@ -1,10 +1,9 @@
1
  import streamlit as st
2
- from transformers import BertForSequenceClassification, BertTokenizer
3
  import torch
4
  import time
5
- import random
6
 
7
- # [Previous CSS styles remain the same]
8
  def local_css():
9
  st.markdown("""
10
  <style>
@@ -38,14 +37,6 @@ def local_css():
38
  max-width: 70%;
39
  }
40
 
41
- .chat-input {
42
- position: fixed;
43
- bottom: 0;
44
- width: 100%;
45
- padding: 20px;
46
- background-color: white;
47
- }
48
-
49
  .thinking-animation {
50
  display: flex;
51
  align-items: center;
@@ -71,69 +62,38 @@ def local_css():
71
  </style>
72
  """, unsafe_allow_html=True)
73
 
 
74
  @st.cache_resource
75
  def load_model():
76
- model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
77
- tokenizer = BertTokenizer.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
 
 
78
  return model, tokenizer
79
 
80
- def predict(text, model, tokenizer):
81
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
82
 
 
83
  with torch.no_grad():
84
- outputs = model(**inputs)
85
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
86
- predicted_class = torch.argmax(predictions, dim=1).item()
87
- confidence = predictions[0][predicted_class].item()
88
-
89
- return predicted_class, confidence
90
-
91
- def get_bot_response(text, predicted_class, confidence):
92
- # Define response templates based on classes and confidence levels
93
- responses = {
94
- 0: { # Example for class 0 (positive sentiment)
95
- 'high_conf': [
96
- "Tôi cảm nhận được sự tích cực trong câu nói của bạn. Xin chia sẻ thêm nhé!",
97
- "Thật vui khi nghe điều đó. Bạn có thể kể thêm không?",
98
- "Tuyệt vời! Tôi rất đồng ý với bạn về điều này."
99
- ],
100
- 'low_conf': [
101
- "Có vẻ như đây là điều tích cực. Đúng không nhỉ?",
102
- "Tôi nghĩ đây là một góc nhìn thú vị đấy.",
103
- "Nghe có vẻ tốt đấy, bạn nghĩ sao?"
104
- ]
105
- },
106
- 1: { # Example for class 1 (negative sentiment)
107
- 'high_conf': [
108
- "Tôi hiểu đây là điều khó khăn với bạn. Hãy chia sẻ thêm nhé.",
109
- "Tôi rất tiếc khi nghe điều này. Bạn cần tôi giúp gì không?",
110
- "Đúng là một tình huống khó khăn. Chúng ta cùng tìm giải pháp nhé."
111
- ],
112
- 'low_conf': [
113
- "Có vẻ như bạn đang gặp khó khăn. Tôi có hiểu đúng không?",
114
- "Tôi không chắc mình hiểu hết, bạn có thể giải thích thêm được không?",
115
- "Hãy chia sẻ thêm để tôi có thể hiểu rõ hơn nhé."
116
- ]
117
- }
118
- }
119
-
120
- # Add more classes based on your model's output
121
-
122
- # Determine confidence level
123
- confidence_threshold = 0.8
124
- conf_level = 'high_conf' if confidence > confidence_threshold else 'low_conf'
125
 
126
- # Get appropriate response list
127
- try:
128
- response_list = responses[predicted_class][conf_level]
129
- response = random.choice(response_list)
130
- except KeyError:
131
- response = "Xin lỗi, tôi không chắc chắn về điều này. Bạn có thể giải thích rõ hơn được không?"
132
-
133
- # Add context from user's input
134
- context_response = f"{response}"
135
-
136
- return context_response
137
 
138
  def init_session_state():
139
  if 'messages' not in st.session_state:
@@ -149,7 +109,12 @@ def display_chat_history():
149
  st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
150
 
151
  def main():
152
- st.set_page_config(page_title="Vietnamese Chatbot", page_icon="🤖", layout="wide")
 
 
 
 
 
153
  local_css()
154
  init_session_state()
155
 
@@ -157,8 +122,8 @@ def main():
157
  model, tokenizer = load_model()
158
 
159
  # Chat interface
160
- st.title("Chatbot Tiếng Việt 🤖")
161
- st.markdown("Xin chào! Tôi có thể giúp cho bạn?")
162
 
163
  # Chat history container
164
  chat_container = st.container()
@@ -167,7 +132,11 @@ def main():
167
  with st.container():
168
  col1, col2 = st.columns([6, 1])
169
  with col1:
170
- user_input = st.text_input("Nhập tin nhắn của bạn...", key="user_input", label_visibility="hidden")
 
 
 
 
171
  with col2:
172
  send_button = st.button("Gửi")
173
 
@@ -178,11 +147,15 @@ def main():
178
  # Show thinking animation
179
  st.session_state.thinking = True
180
 
181
- # Get prediction
182
- predicted_class, confidence = predict(user_input, model, tokenizer)
 
 
 
183
 
184
  # Generate response
185
- bot_response = get_bot_response(user_input, predicted_class, confidence)
 
186
 
187
  # Add bot response
188
  time.sleep(0.5) # Brief delay for natural feeling
 
1
  import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  import time
 
5
 
6
+ # Custom CSS for the chat interface
7
  def local_css():
8
  st.markdown("""
9
  <style>
 
37
  max-width: 70%;
38
  }
39
 
 
 
 
 
 
 
 
 
40
  .thinking-animation {
41
  display: flex;
42
  align-items: center;
 
62
  </style>
63
  """, unsafe_allow_html=True)
64
 
65
+ # Load model and tokenizer
66
  @st.cache_resource
67
  def load_model():
68
+ # Using VietAI's Vietnamese GPT model
69
+ model_name = "vietai/gpt-neo-1.3B-vietnamese-news"
70
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
71
+ model = AutoModelForCausalLM.from_pretrained(model_name)
72
  return model, tokenizer
73
 
74
+ def generate_response(prompt, model, tokenizer, max_length=100):
75
+ # Prepare input
76
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
77
 
78
+ # Generate response
79
  with torch.no_grad():
80
+ outputs = model.generate(
81
+ inputs.input_ids,
82
+ max_length=max_length,
83
+ num_return_sequences=1,
84
+ temperature=0.7,
85
+ top_k=50,
86
+ top_p=0.95,
87
+ do_sample=True,
88
+ pad_token_id=tokenizer.eos_token_id,
89
+ attention_mask=inputs.attention_mask
90
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ # Decode response
93
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
94
+ # Remove the input prompt from the response
95
+ response = response[len(prompt):].strip()
96
+ return response
 
 
 
 
 
 
97
 
98
  def init_session_state():
99
  if 'messages' not in st.session_state:
 
109
  st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
110
 
111
  def main():
112
+ st.set_page_config(
113
+ page_title="AI Chatbot Tiếng Việt",
114
+ page_icon="🤖",
115
+ layout="wide"
116
+ )
117
+
118
  local_css()
119
  init_session_state()
120
 
 
122
  model, tokenizer = load_model()
123
 
124
  # Chat interface
125
+ st.title("AI Chatbot Tiếng Việt 🤖")
126
+ st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")
127
 
128
  # Chat history container
129
  chat_container = st.container()
 
132
  with st.container():
133
  col1, col2 = st.columns([6, 1])
134
  with col1:
135
+ user_input = st.text_input(
136
+ "Nhập tin nhắn của bạn...",
137
+ key="user_input",
138
+ label_visibility="hidden"
139
+ )
140
  with col2:
141
  send_button = st.button("Gửi")
142
 
 
147
  # Show thinking animation
148
  st.session_state.thinking = True
149
 
150
+ # Prepare conversation history
151
+ conversation_history = "\n".join([
152
+ f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
153
+ for msg in st.session_state.messages[-3:] # Last 3 messages for context
154
+ ])
155
 
156
  # Generate response
157
+ prompt = f"{conversation_history}\nAssistant:"
158
+ bot_response = generate_response(prompt, model, tokenizer)
159
 
160
  # Add bot response
161
  time.sleep(0.5) # Brief delay for natural feeling