# IOTraining / app.py
# Source page metadata: last commit 47ed12b ("Update app.py", verified) by JustKiddo, 5.55 kB.
import html
import time

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Custom CSS for the chat interface
def local_css():
    """Inject the page-wide stylesheet for chat bubbles and the typing indicator.

    Defines the ``user-message`` / ``bot-message`` bubble classes and the
    three bouncing ``dot`` elements used by ``thinking-animation``.
    """
    stylesheet = """
<style>
.chat-container {
padding: 10px;
border-radius: 5px;
margin-bottom: 10px;
display: flex;
flex-direction: column;
}
.user-message {
background-color: #e3f2fd;
padding: 10px;
border-radius: 15px;
margin: 5px;
margin-left: 20%;
margin-right: 5px;
align-self: flex-end;
max-width: 70%;
}
.bot-message {
background-color: #f5f5f5;
padding: 10px;
border-radius: 15px;
margin: 5px;
margin-right: 20%;
margin-left: 5px;
align-self: flex-start;
max-width: 70%;
}
.thinking-animation {
display: flex;
align-items: center;
margin-left: 10px;
}
.dot {
width: 8px;
height: 8px;
margin: 0 3px;
background: #888;
border-radius: 50%;
animation: bounce 0.8s infinite;
}
.dot:nth-child(2) { animation-delay: 0.2s; }
.dot:nth-child(3) { animation-delay: 0.4s; }
@keyframes bounce {
0%, 100% { transform: translateY(0); }
50% { transform: translateY(-5px); }
}
</style>
"""
    # Raw HTML is required for a <style> tag to reach the page.
    st.markdown(stylesheet, unsafe_allow_html=True)
# Load model and tokenizer
@st.cache_resource
def load_model(model_name="tamgrnguyen/Gemma-2-2b-it-Vietnamese-Aesthetic"):
    """Load the causal language model and its tokenizer, cached by Streamlit.

    ``st.cache_resource`` memoizes the result per distinct ``model_name``, so
    script reruns reuse the already-loaded weights instead of reloading them.

    Args:
        model_name: Hugging Face Hub id or local path passed to
            ``from_pretrained``. Defaults to the
            ``tamgrnguyen/Gemma-2-2b-it-Vietnamese-Aesthetic`` checkpoint
            (a Gemma-2-2b-it based model, per its name).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
def generate_response(prompt, model, tokenizer, max_length=100):
    """Sample a continuation of ``prompt`` and return only the newly generated text.

    Args:
        prompt: Full text prompt, e.g. the conversation history ending in
            ``"Assistant:"``.
        model: Causal LM exposing ``generate`` and ``device`` (as returned by
            ``AutoModelForCausalLM.from_pretrained``).
        tokenizer: Tokenizer matching ``model``.
        max_length: Maximum number of *new* tokens to generate. It is forwarded
            as ``max_new_tokens``. Passing it as ``max_length`` would count the
            prompt tokens against the same budget, so a long history would
            leave little or no room for the reply.

    Returns:
        The decoded completion, without the prompt, with surrounding
        whitespace stripped. May be an empty string.
    """
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    # Keep input tensors on the same device as the model weights.
    inputs = inputs.to(model.device)
    prompt_token_count = inputs.input_ids.shape[1]

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_length,
            num_return_sequences=1,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion. Drop the prompt by token count.
    # Decoding need not reproduce the original prompt string exactly
    # (special tokens, whitespace), so slicing by characters can cut into
    # the reply or leak part of the prompt.
    completion_ids = outputs[0][prompt_token_count:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
def init_session_state():
    """Create the per-session keys this app relies on, keeping existing values.

    ``messages`` holds the chat history (list of role/content dicts) and
    ``thinking`` is the flag that shows the typing indicator.
    """
    # A fresh dict per call, so each session gets its own empty list.
    defaults = {"messages": [], "thinking": False}
    for key, initial in defaults.items():
        if key in st.session_state:
            continue
        st.session_state[key] = initial
def format_message_html(message):
    """Return the HTML bubble markup for one chat message.

    A message whose ``role`` is ``"user"`` gets the ``user-message`` class.
    Any other role gets ``bot-message``. The content comes from user input and
    model output and is rendered with ``unsafe_allow_html``, so it is
    HTML-escaped to prevent markup/script injection. Newlines become ``<br>``
    so multi-line replies show as line breaks inside the bubble.
    """
    css_class = "user-message" if message["role"] == "user" else "bot-message"
    body = html.escape(message["content"]).replace("\n", "<br>")
    return f'<div class="{css_class}">{body}</div>'


def display_chat_history():
    """Render every message in ``st.session_state.messages`` as a styled bubble."""
    for message in st.session_state.messages:
        st.markdown(format_message_html(message), unsafe_allow_html=True)
def main():
    """Build the chat page and handle one send -> generate cycle.

    On submit, the user's message is appended to the history and
    ``thinking`` is set. The history and typing indicator are drawn before the
    model runs, so the indicator is visible during generation. The reply is
    then appended and the script reruns to draw the final history.
    """
    st.set_page_config(
        page_title="AI Chatbot Tiếng Việt",
        page_icon="🤖",
        layout="wide",
    )
    local_css()
    init_session_state()

    # Load model (memoized by st.cache_resource on load_model)
    model, tokenizer = load_model()

    st.title("AI Chatbot Tiếng Việt 🤖")
    st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")

    # Created before the input form so the history is drawn above it.
    chat_container = st.container()

    # clear_on_submit empties the box after sending. st.rerun() alone does not
    # reset a keyed text_input, so the old text would stay and could be re-sent.
    with st.form("chat_form", clear_on_submit=True):
        col1, col2 = st.columns([6, 1])
        with col1:
            user_input = st.text_input(
                "Nhập tin nhắn của bạn...",
                key="user_input",
                label_visibility="hidden",
            )
        with col2:
            send_button = st.form_submit_button("Gửi")

    # Ignore empty / whitespace-only submissions.
    if send_button and user_input.strip():
        st.session_state.messages.append({"role": "user", "content": user_input})
        st.session_state.thinking = True

    with chat_container:
        display_chat_history()
        if st.session_state.thinking:
            # Drawn before generation starts so it is visible while the model runs.
            st.markdown(
                '<div class="thinking-animation">'
                '<div class="dot"></div>'
                '<div class="dot"></div>'
                '<div class="dot"></div>'
                "</div>",
                unsafe_allow_html=True,
            )
            conversation_history = "\n".join(
                f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
                for msg in st.session_state.messages[-3:]  # Last 3 messages for context
            )
            prompt = f"{conversation_history}\nAssistant:"
            try:
                bot_response = generate_response(prompt, model, tokenizer)
            finally:
                # Reset even on failure so later reruns don't retry generation endlessly.
                st.session_state.thinking = False
            st.session_state.messages.append({"role": "assistant", "content": bot_response})
            # Rerun to redraw the history without the typing indicator.
            st.rerun()
if __name__ == "__main__":
    # Build the app only when this file is executed as the main script, not on import.
    main()