Hugging Face Spaces — diff view (Space status: Sleeping)
Commit: "Update app.py"
File changed: app.py
@@ -69,31 +69,46 @@ def load_model():
|
|
69 |
model_name = "vietai/gpt-neo-1.3B-vietnamese-news"
|
70 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
71 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
|
|
|
|
|
|
|
|
72 |
return model, tokenizer
|
73 |
|
74 |
def generate_response(prompt, model, tokenizer, max_length=100):
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
def init_session_state():
|
99 |
if 'messages' not in st.session_state:
|
@@ -118,65 +133,70 @@ def main():
|
|
118 |
local_css()
|
119 |
init_session_state()
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
# Chat interface
|
125 |
-
st.title("AI Chatbot Tiếng Việt 🤖")
|
126 |
-
st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")
|
127 |
-
|
128 |
-
# Chat history container
|
129 |
-
chat_container = st.container()
|
130 |
-
|
131 |
-
# Input container
|
132 |
-
with st.container():
|
133 |
-
col1, col2 = st.columns([6, 1])
|
134 |
-
with col1:
|
135 |
-
user_input = st.text_input(
|
136 |
-
"Nhập tin nhắn của bạn...",
|
137 |
-
key="user_input",
|
138 |
-
label_visibility="hidden"
|
139 |
-
)
|
140 |
-
with col2:
|
141 |
-
send_button = st.button("Gửi")
|
142 |
-
|
143 |
-
if user_input and send_button:
|
144 |
-
# Add user message
|
145 |
-
st.session_state.messages.append({"role": "user", "content": user_input})
|
146 |
|
147 |
-
#
|
148 |
-
st.
|
|
|
149 |
|
150 |
-
#
|
151 |
-
|
152 |
-
f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
|
153 |
-
for msg in st.session_state.messages[-3:] # Last 3 messages for context
|
154 |
-
])
|
155 |
|
156 |
-
#
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
164 |
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
if __name__ == "__main__":
|
182 |
main()
|
|
|
def load_model():
    """Load the Vietnamese GPT-Neo model and its tokenizer.

    Returns:
        tuple: ``(model, tokenizer)`` ready for text generation.
    """
    model_name = "vietai/gpt-neo-1.3B-vietnamese-news"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # GPT-Neo ships without a dedicated pad token; reuse EOS so that
    # padded/batched generation does not raise a missing-pad-token error.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    return model, tokenizer
def generate_response(prompt, model, tokenizer, max_length=100):
    """Generate a chat reply for *prompt* with a causal language model.

    Args:
        prompt: Conversation history ending with ``"Assistant:"``.
        model: Hugging Face causal-LM model (``generate``-capable).
        tokenizer: The tokenizer paired with *model*.
        max_length: Maximum total sequence length (prompt + new tokens).

    Returns:
        str: The generated reply, or a Vietnamese fallback message when
        generation is empty or an error occurs.
    """
    try:
        # Tokenize the prompt; truncation keeps us within the model's limit.
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

        # Sampling-based decoding; no gradients are needed at inference time.
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                attention_mask=inputs.attention_mask,
            )

        # BUGFIX: strip the prompt by *token* count, not character count.
        # The old `response[len(prompt):]` mis-slices whenever decode() does
        # not reproduce the prompt text byte-for-byte (tokenizer
        # normalization / whitespace handling); slicing the generated ids
        # before decoding is always correct.
        prompt_token_count = inputs.input_ids.shape[-1]
        response = tokenizer.decode(
            outputs[0][prompt_token_count:], skip_special_tokens=True
        ).strip()

        # If the model produced nothing new, return a polite default message.
        if not response:
            return "Xin lỗi, tôi không thể tạo câu trả lời. Bạn có thể hỏi lại không?"

        return response

    except Exception as e:
        # Surface the error in the Streamlit UI but keep the chat loop alive.
        st.error(f"Error generating response: {str(e)}")
        return "Xin lỗi, đã có lỗi xảy ra. Vui lòng thử lại."
112 |
|
113 |
def init_session_state():
|
114 |
if 'messages' not in st.session_state:
|
|
|
def main():
    """Render the Streamlit chat UI and drive one request/response turn."""
    local_css()
    init_session_state()

    try:
        # Load model and tokenizer (heavy; errors fall through to the
        # user-facing handler below).
        model, tokenizer = load_model()

        # Chat interface header.
        st.title("AI Chatbot Tiếng Việt 🤖")
        st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")

        # Container that will later hold the rendered chat history.
        chat_container = st.container()

        # Input row: text box plus send button.
        with st.container():
            col1, col2 = st.columns([6, 1])
            with col1:
                user_input = st.text_input(
                    "Nhập tin nhắn của bạn...",
                    key="user_input",
                    label_visibility="hidden"
                )
            with col2:
                send_button = st.button("Gửi")

        if user_input and send_button:
            # Record the user's message.
            st.session_state.messages.append({"role": "user", "content": user_input})

            # Flag the "thinking" animation for the next render.
            st.session_state.thinking = True

            # Build the prompt from recent history (last 3 messages for context).
            conversation_history = "\n".join([
                f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
                for msg in st.session_state.messages[-3:]
            ])

            # Generate the assistant's reply.
            prompt = f"{conversation_history}\nAssistant:"
            bot_response = generate_response(prompt, model, tokenizer)

            # Record the bot response; brief delay for a natural feel.
            time.sleep(0.5)
            st.session_state.messages.append({"role": "assistant", "content": bot_response})
            st.session_state.thinking = False

            # Clear input and rerun to refresh the UI.
            st.rerun()

        # Render the accumulated chat history.
        with chat_container:
            display_chat_history()

        if st.session_state.thinking:
            st.markdown("""
            <div class="thinking-animation">
                <div class="dot"></div>
                <div class="dot"></div>
                <div class="dot"></div>
            </div>
            """, unsafe_allow_html=True)

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.info("Please refresh the page to try again.")
# Script entry point: launch the Streamlit app.
if __name__ == "__main__":
    main()