JustKiddo committed on
Commit
76bb75b
1 Parent(s): cb4d8f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -76
app.py CHANGED
@@ -69,31 +69,46 @@ def load_model():
69
  model_name = "vietai/gpt-neo-1.3B-vietnamese-news"
70
  tokenizer = AutoTokenizer.from_pretrained(model_name)
71
  model = AutoModelForCausalLM.from_pretrained(model_name)
 
 
 
 
 
72
  return model, tokenizer
73
 
74
  def generate_response(prompt, model, tokenizer, max_length=100):
75
- # Prepare input
76
- inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
77
-
78
- # Generate response
79
- with torch.no_grad():
80
- outputs = model.generate(
81
- inputs.input_ids,
82
- max_length=max_length,
83
- num_return_sequences=1,
84
- temperature=0.7,
85
- top_k=50,
86
- top_p=0.95,
87
- do_sample=True,
88
- pad_token_id=tokenizer.eos_token_id,
89
- attention_mask=inputs.attention_mask
90
- )
91
-
92
- # Decode response
93
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
94
- # Remove the input prompt from the response
95
- response = response[len(prompt):].strip()
96
- return response
 
 
 
 
 
 
 
 
 
 
97
 
98
  def init_session_state():
99
  if 'messages' not in st.session_state:
@@ -118,65 +133,70 @@ def main():
118
  local_css()
119
  init_session_state()
120
 
121
- # Load model
122
- model, tokenizer = load_model()
123
-
124
- # Chat interface
125
- st.title("AI Chatbot Tiếng Việt 🤖")
126
- st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")
127
-
128
- # Chat history container
129
- chat_container = st.container()
130
-
131
- # Input container
132
- with st.container():
133
- col1, col2 = st.columns([6, 1])
134
- with col1:
135
- user_input = st.text_input(
136
- "Nhập tin nhắn của bạn...",
137
- key="user_input",
138
- label_visibility="hidden"
139
- )
140
- with col2:
141
- send_button = st.button("Gửi")
142
-
143
- if user_input and send_button:
144
- # Add user message
145
- st.session_state.messages.append({"role": "user", "content": user_input})
146
 
147
- # Show thinking animation
148
- st.session_state.thinking = True
 
149
 
150
- # Prepare conversation history
151
- conversation_history = "\n".join([
152
- f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
153
- for msg in st.session_state.messages[-3:] # Last 3 messages for context
154
- ])
155
 
156
- # Generate response
157
- prompt = f"{conversation_history}\nAssistant:"
158
- bot_response = generate_response(prompt, model, tokenizer)
159
-
160
- # Add bot response
161
- time.sleep(0.5) # Brief delay for natural feeling
162
- st.session_state.messages.append({"role": "assistant", "content": bot_response})
163
- st.session_state.thinking = False
 
 
 
164
 
165
- # Clear input and rerun
166
- st.rerun()
167
-
168
- # Display chat history
169
- with chat_container:
170
- display_chat_history()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- if st.session_state.thinking:
173
- st.markdown("""
174
- <div class="thinking-animation">
175
- <div class="dot"></div>
176
- <div class="dot"></div>
177
- <div class="dot"></div>
178
- </div>
179
- """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
180
 
181
  if __name__ == "__main__":
182
  main()
 
69
  model_name = "vietai/gpt-neo-1.3B-vietnamese-news"
70
  tokenizer = AutoTokenizer.from_pretrained(model_name)
71
  model = AutoModelForCausalLM.from_pretrained(model_name)
72
+
73
+ # Fix the padding token issue
74
+ tokenizer.pad_token = tokenizer.eos_token
75
+ model.config.pad_token_id = model.config.eos_token_id
76
+
77
  return model, tokenizer
78
 
79
  def generate_response(prompt, model, tokenizer, max_length=100):
80
+ try:
81
+ # Prepare input
82
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
83
+
84
+ # Generate response
85
+ with torch.no_grad():
86
+ outputs = model.generate(
87
+ inputs.input_ids,
88
+ max_length=max_length,
89
+ num_return_sequences=1,
90
+ temperature=0.7,
91
+ top_k=50,
92
+ top_p=0.95,
93
+ do_sample=True,
94
+ pad_token_id=tokenizer.pad_token_id,
95
+ attention_mask=inputs.attention_mask
96
+ )
97
+
98
+ # Decode response
99
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
100
+ # Remove the input prompt from the response
101
+ response = response[len(prompt):].strip()
102
+
103
+ # If response is empty, return a default message
104
+ if not response:
105
+ return "Xin lỗi, tôi không thể tạo câu trả lời. Bạn có thể hỏi lại không?"
106
+
107
+ return response
108
+
109
+ except Exception as e:
110
+ st.error(f"Error generating response: {str(e)}")
111
+ return "Xin lỗi, đã có lỗi xảy ra. Vui lòng thử lại."
112
 
113
  def init_session_state():
114
  if 'messages' not in st.session_state:
 
133
  local_css()
134
  init_session_state()
135
 
136
+ try:
137
+ # Load model
138
+ model, tokenizer = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ # Chat interface
141
+ st.title("AI Chatbot Tiếng Việt 🤖")
142
+ st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")
143
 
144
+ # Chat history container
145
+ chat_container = st.container()
 
 
 
146
 
147
+ # Input container
148
+ with st.container():
149
+ col1, col2 = st.columns([6, 1])
150
+ with col1:
151
+ user_input = st.text_input(
152
+ "Nhập tin nhắn của bạn...",
153
+ key="user_input",
154
+ label_visibility="hidden"
155
+ )
156
+ with col2:
157
+ send_button = st.button("Gửi")
158
 
159
+ if user_input and send_button:
160
+ # Add user message
161
+ st.session_state.messages.append({"role": "user", "content": user_input})
162
+
163
+ # Show thinking animation
164
+ st.session_state.thinking = True
165
+
166
+ # Prepare conversation history
167
+ conversation_history = "\n".join([
168
+ f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
169
+ for msg in st.session_state.messages[-3:] # Last 3 messages for context
170
+ ])
171
+
172
+ # Generate response
173
+ prompt = f"{conversation_history}\nAssistant:"
174
+ bot_response = generate_response(prompt, model, tokenizer)
175
+
176
+ # Add bot response
177
+ time.sleep(0.5) # Brief delay for natural feeling
178
+ st.session_state.messages.append({"role": "assistant", "content": bot_response})
179
+ st.session_state.thinking = False
180
+
181
+ # Clear input and rerun
182
+ st.rerun()
183
 
184
+ # Display chat history
185
+ with chat_container:
186
+ display_chat_history()
187
+
188
+ if st.session_state.thinking:
189
+ st.markdown("""
190
+ <div class="thinking-animation">
191
+ <div class="dot"></div>
192
+ <div class="dot"></div>
193
+ <div class="dot"></div>
194
+ </div>
195
+ """, unsafe_allow_html=True)
196
+
197
+ except Exception as e:
198
+ st.error(f"An error occurred: {str(e)}")
199
+ st.info("Please refresh the page to try again.")
200
 
201
  if __name__ == "__main__":
202
  main()