# Spaces: Sleeping
# (Page-status text captured along with the source; not part of the program.)
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import time | |
# Styling for the chat interface: message bubbles and the "thinking" dots.
def local_css():
    """Inject the chat UI stylesheet into the current Streamlit page.

    The stylesheet defines the ``chat-container``, ``user-message``,
    ``bot-message``, ``thinking-animation`` and ``dot`` classes plus the
    ``bounce`` keyframes. It is emitted through ``st.markdown`` with
    ``unsafe_allow_html=True`` so the ``<style>`` element is passed through
    as raw HTML.
    """
    stylesheet = """
<style>
.chat-container {
padding: 10px;
border-radius: 5px;
margin-bottom: 10px;
display: flex;
flex-direction: column;
}
.user-message {
background-color: #e3f2fd;
padding: 10px;
border-radius: 15px;
margin: 5px;
margin-left: 20%;
margin-right: 5px;
align-self: flex-end;
max-width: 70%;
}
.bot-message {
background-color: #f5f5f5;
padding: 10px;
border-radius: 15px;
margin: 5px;
margin-right: 20%;
margin-left: 5px;
align-self: flex-start;
max-width: 70%;
}
.thinking-animation {
display: flex;
align-items: center;
margin-left: 10px;
}
.dot {
width: 8px;
height: 8px;
margin: 0 3px;
background: #888;
border-radius: 50%;
animation: bounce 0.8s infinite;
}
.dot:nth-child(2) { animation-delay: 0.2s; }
.dot:nth-child(3) { animation-delay: 0.4s; }
@keyframes bounce {
0%, 100% { transform: translateY(0); }
50% { transform: translateY(-5px); }
}
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
# Load model and tokenizer (cached: Streamlit re-executes the script on every
# interaction, and reloading the weights each time would be prohibitively slow).
@st.cache_resource
def load_model():
    """Load the Vietnamese GPT-Neo model and its tokenizer.

    The result is cached with ``st.cache_resource`` so the weights are loaded
    once and the same objects are reused by later script reruns instead of
    being re-created on every user interaction.

    Returns:
        A ``(model, tokenizer)`` tuple for
        ``vietai/gpt-neo-1.3B-vietnamese-news``.
    """
    # Using VietAI's Vietnamese GPT model
    model_name = "vietai/gpt-neo-1.3B-vietnamese-news"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Padding-token workaround kept from the original code: reuse EOS as the
    # pad token on both the tokenizer and the model config, so tokenizing
    # with padding=True and generate() share a pad id.
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    return model, tokenizer
def generate_response(prompt, model, tokenizer, max_length=100):
    """Generate a reply that continues ``prompt`` using a causal language model.

    Args:
        prompt: Text to condition on (the caller passes the recent
            conversation followed by an ``Assistant:`` cue).
        model: Object exposing a Hugging Face-style ``generate`` method.
        tokenizer: Tokenizer matching ``model``; it is called on the prompt
            and used to decode the generated token ids.
        max_length: Upper bound on the number of *newly generated* tokens.
            It is forwarded as ``max_new_tokens`` so that a long prompt does
            not consume the whole budget and leave no room for a reply.

    Returns:
        The generated continuation with surrounding whitespace removed. A
        fixed Vietnamese apology is returned when the continuation is empty
        or when generation raises; in the latter case the error is also
        shown with ``st.error``.
    """
    try:
        # Prepare input
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        prompt_token_count = inputs.input_ids.shape[-1]

        # Generate response (sampling; gradients are not needed for inference)
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_length,
                num_return_sequences=1,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the tokens produced after the prompt. Slicing the
        # decoded text by len(prompt) is unreliable because decoding does not
        # always reproduce the prompt character-for-character.
        generated_tokens = outputs[0][prompt_token_count:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        # If response is empty, return a default message
        if not response:
            return "Xin lỗi, tôi không thể tạo câu trả lời. Bạn có thể hỏi lại không?"
        return response
    except Exception as e:
        # Top-level boundary for the UI: report the failure and keep the chat usable.
        st.error(f"Error generating response: {str(e)}")
        return "Xin lỗi, đã có lỗi xảy ra. Vui lòng thử lại."
def init_session_state():
    """Create the session-state entries the chat UI relies on, if absent.

    ``messages`` starts as an empty list of chat turns and ``thinking`` as
    ``False``. Entries that already exist are left untouched, so values
    survive script reruns.
    """
    initial_values = (
        ('messages', []),
        ('thinking', False),
    )
    for state_key, initial in initial_values:
        if state_key not in st.session_state:
            st.session_state[state_key] = initial
def display_chat_history():
    """Render every stored chat message as a styled bubble.

    Messages whose ``role`` is ``'user'`` use the ``user-message`` CSS class;
    all others use ``bot-message``. Because the markup is emitted with
    ``unsafe_allow_html=True``, message text is HTML-escaped first so that
    user input or model output containing ``<``, ``>`` or ``&`` is displayed
    literally instead of being interpreted as markup.
    """
    import html  # function-scope import keeps this block self-contained

    for message in st.session_state.messages:
        css_class = "user-message" if message['role'] == 'user' else "bot-message"
        # Escape untrusted text, then turn newlines into <br> so a blank line
        # inside a message cannot terminate the surrounding HTML block.
        safe_text = html.escape(message["content"]).replace("\n", "<br>")
        st.markdown(f'<div class="{css_class}">{safe_text}</div>', unsafe_allow_html=True)
def main():
    """Run the Streamlit chat page.

    Configures the page, injects the CSS, initialises session state, loads
    the model, then draws the chat history above a one-line input form.

    When a non-blank message is submitted:
      1. the user turn is stored in ``st.session_state.messages``;
      2. the history and the "thinking" dots are rendered right away, so the
         user gets feedback while generation blocks;
      3. a prompt is built from the last three messages and passed to
         ``generate_response``;
      4. the assistant turn is stored and the script is rerun to redraw.

    The input lives in a form with ``clear_on_submit=True`` so the text box
    is emptied after each send. Any exception is reported on the page
    rather than propagated.
    """
    st.set_page_config(
        page_title="AI Chatbot Tiếng Việt",
        page_icon="🤖",
        layout="wide",
    )
    local_css()
    init_session_state()

    thinking_html = """
<div class="thinking-animation">
<div class="dot"></div>
<div class="dot"></div>
<div class="dot"></div>
</div>
"""

    try:
        # Load model
        model, tokenizer = load_model()

        # Chat interface
        st.title("AI Chatbot Tiếng Việt 🤖")
        st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")

        # Created before the input so the history appears above it on the page.
        chat_container = st.container()

        # Input form: clear_on_submit resets the text box after sending.
        with st.form("chat_form", clear_on_submit=True):
            col1, col2 = st.columns([6, 1])
            with col1:
                user_input = st.text_input(
                    "Nhập tin nhắn của bạn...",
                    key="user_input",
                    label_visibility="hidden",
                )
            with col2:
                send_button = st.form_submit_button("Gửi")

        # Ignore whitespace-only submissions.
        message_text = user_input.strip() if user_input else ""

        if send_button and message_text:
            # Add user message
            st.session_state.messages.append({"role": "user", "content": message_text})
            st.session_state.thinking = True
            try:
                # Draw the conversation and the animation *before* the
                # blocking generate call so they are visible while it runs.
                with chat_container:
                    display_chat_history()
                    st.markdown(thinking_html, unsafe_allow_html=True)

                # Prepare conversation history (last 3 messages for context)
                conversation_history = "\n".join(
                    f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
                    for msg in st.session_state.messages[-3:]
                )

                # Generate response
                prompt = f"{conversation_history}\nAssistant:"
                bot_response = generate_response(prompt, model, tokenizer)

                # Add bot response
                time.sleep(0.5)  # Brief delay for natural feeling
                st.session_state.messages.append({"role": "assistant", "content": bot_response})
            finally:
                # Never leave the flag set, even if this run is interrupted.
                st.session_state.thinking = False

            # Redraw the page with the new turn and the cleared input.
            st.rerun()

        # Display chat history
        with chat_container:
            display_chat_history()
    except Exception as e:
        # Top-level UI boundary: show the problem instead of a traceback.
        st.error(f"An error occurred: {str(e)}")
        st.info("Please refresh the page to try again.")
# Entry point: build the page only when this file is executed as the main
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()