"""Streamlit front-end for a Vietnamese-language chatbot backed by a
Hugging Face causal language model."""

import html
import time
from typing import Optional

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face model id; judging by the name, a Vietnamese fine-tune of
# Gemma-2-2B-it.
MODEL_NAME = "tamgrnguyen/Gemma-2-2b-it-Vietnamese-Aesthetic"

# Number of most recent chat messages included in the prompt as context.
CONTEXT_MESSAGES = 3

# Budget of *generated* tokens per reply, independent of the prompt length.
MAX_NEW_TOKENS = 256

# Role prefixes used to serialise the conversation into a plain-text prompt.
USER_PREFIX = "User: "
ASSISTANT_PREFIX = "Assistant: "

# Stylesheet for the chat bubbles. Kept flush-left on purpose: Markdown treats
# lines indented by 4+ spaces as a code block instead of raw HTML.
_CHAT_CSS = """
<style>
.user-message {
    background-color: #DCF8C6;
    color: #000000;
    padding: 0.6rem 1rem;
    border-radius: 1rem;
    margin: 0.4rem 0 0.4rem 20%;
    text-align: right;
}
.bot-message {
    background-color: #F1F0F0;
    color: #000000;
    padding: 0.6rem 1rem;
    border-radius: 1rem;
    margin: 0.4rem 20% 0.4rem 0;
}
</style>
"""


# Custom CSS for the chat interface
def local_css():
    """Inject the stylesheet used by the chat bubbles."""
    st.markdown(_CHAT_CSS, unsafe_allow_html=True)


# Load model and tokenizer once per server process (cached by Streamlit).
@st.cache_resource
def load_model():
    """Load and return ``(model, tokenizer)`` for ``MODEL_NAME``."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return model, tokenizer


def generate_response(
    prompt: str,
    model,
    tokenizer,
    max_length: int = 100,
    max_new_tokens: Optional[int] = None,
) -> str:
    """Sample a continuation of *prompt* and return only the newly generated text.

    Args:
        prompt: Plain-text prompt, typically ending in ``"Assistant:"``.
        model: A Hugging Face causal LM exposing ``generate`` and ``device``.
        tokenizer: The tokenizer matching *model*.
        max_length: Total length cap (prompt + reply) in tokens. It is only
            used when *max_new_tokens* is None, which preserves the original
            behaviour for existing callers.
        max_new_tokens: Cap on generated tokens only. Prefer this, because
            *max_length* leaves no room for a reply once the prompt is long.

    Returns:
        The decoded reply. Special tokens are skipped, and the text is cut off
        at any hallucinated next ``User:`` turn.
    """
    inputs = tokenizer(
        prompt, return_tensors="pt", padding=True, truncation=True
    ).to(model.device)

    # Exactly one length limit is passed, so transformers does not warn about conflicts.
    if max_new_tokens is not None:
        length_kwargs = {"max_new_tokens": max_new_tokens}
    else:
        length_kwargs = {"max_length": max_length}

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_return_sequences=1,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            **length_kwargs,
        )

    # For decoder-only models, generate() returns prompt + new tokens. Slicing
    # token ids is exact; slicing the decoded string by len(prompt) is not,
    # because decoding need not reproduce the prompt text verbatim.
    prompt_len = inputs.input_ids.shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # The model may keep going and invent the next user turn; drop it.
    response = response.split("\n" + USER_PREFIX.strip())[0]
    return response.strip()


def init_session_state():
    """Create the per-session keys this app reads, if not already present."""
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "thinking" not in st.session_state:
        st.session_state.thinking = False


def _render_message(content: str, css_class: str) -> None:
    """Render one chat bubble, HTML-escaping *content* (user/model text is untrusted)."""
    # Newlines become <br> so that a blank line cannot end the raw-HTML block
    # and leak the rest of the message into Markdown parsing.
    safe = html.escape(content).replace("\n", "<br>")
    st.markdown(f'<div class="{css_class}">{safe}</div>', unsafe_allow_html=True)


def display_chat_history():
    """Render every stored message, styled by its role."""
    for message in st.session_state.messages:
        if message["role"] == "user":
            _render_message(message["content"], "user-message")
        else:
            _render_message(message["content"], "bot-message")


def main():
    """Build the page, handle one submitted message, and render the history."""
    st.set_page_config(
        page_title="AI Chatbot Tiếng Việt",
        page_icon="🤖",
        layout="wide",
    )
    local_css()
    init_session_state()

    model, tokenizer = load_model()

    st.title("AI Chatbot Tiếng Việt 🤖")
    st.markdown(
        "Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. "
        "Hãy hỏi tôi bất cứ điều gì!"
    )

    # Created first so the history appears above the input box, but filled
    # at the end of the script.
    chat_container = st.container()

    # A form with clear_on_submit empties the text box after each send. A
    # widget-bound session_state key cannot be reset after the widget exists.
    with st.form("chat_form", clear_on_submit=True):
        col1, col2 = st.columns([6, 1])
        with col1:
            user_input = st.text_input(
                "Nhập tin nhắn của bạn...",
                key="user_input",
                label_visibility="hidden",
            )
        with col2:
            send_button = st.form_submit_button("Gửi")

    if send_button and user_input.strip():
        st.session_state.messages.append(
            {"role": "user", "content": user_input.strip()}
        )

        # NOTE(review): Gemma-2 "-it" models are trained on their own chat
        # template; tokenizer.apply_chat_template may give better replies than
        # this ad-hoc "User:/Assistant:" format. Confirm before switching.
        conversation_history = "\n".join(
            f"{USER_PREFIX if msg['role'] == 'user' else ASSISTANT_PREFIX}{msg['content']}"
            for msg in st.session_state.messages[-CONTEXT_MESSAGES:]
        )
        prompt = f"{conversation_history}\nAssistant:"

        # The flag is set and reset within this run, so nothing rendered later
        # could ever observe it as True. A spinner shows progress instead.
        st.session_state.thinking = True
        try:
            with st.spinner("Đang suy nghĩ..."):
                bot_response = generate_response(
                    prompt, model, tokenizer, max_new_tokens=MAX_NEW_TOKENS
                )
                time.sleep(0.5)  # Brief delay for a more natural feel
        finally:
            st.session_state.thinking = False

        st.session_state.messages.append(
            {"role": "assistant", "content": bot_response}
        )
        st.rerun()

    with chat_container:
        display_chat_history()


if __name__ == "__main__":
    main()