Spaces:
Running
Running
import html
import time

import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import BertForSequenceClassification, BertTokenizer
# Custom CSS for the chat interface.
def local_css():
    """Inject the stylesheet for chat bubbles and the typing indicator.

    Emits one ``<style>`` element through ``st.markdown`` with raw HTML
    enabled. It defines the ``chat-container``, ``user-message``,
    ``bot-message``, ``chat-input``, ``thinking-animation`` and ``dot``
    classes plus the ``bounce`` keyframes used by the dots.
    """
    stylesheet = """
<style>
.chat-container {
padding: 10px;
border-radius: 5px;
margin-bottom: 10px;
display: flex;
flex-direction: column;
}
.user-message {
background-color: #e3f2fd;
padding: 10px;
border-radius: 15px;
margin: 5px;
margin-left: 20%;
margin-right: 5px;
align-self: flex-end;
max-width: 70%;
}
.bot-message {
background-color: #f5f5f5;
padding: 10px;
border-radius: 15px;
margin: 5px;
margin-right: 20%;
margin-left: 5px;
align-self: flex-start;
max-width: 70%;
}
.chat-input {
position: fixed;
bottom: 0;
width: 100%;
padding: 20px;
background-color: white;
}
.thinking-animation {
display: flex;
align-items: center;
margin-left: 10px;
}
.dot {
width: 8px;
height: 8px;
margin: 0 3px;
background: #888;
border-radius: 50%;
animation: bounce 0.8s infinite;
}
.dot:nth-child(2) { animation-delay: 0.2s; }
.dot:nth-child(3) { animation-delay: 0.4s; }
@keyframes bounce {
0%, 100% { transform: translateY(0); }
50% { transform: translateY(-5px); }
}
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)
# Load model and tokenizer
MODEL_NAME = "trituenhantaoio/bert-base-vietnamese-uncased"


@st.cache_resource
def load_model(model_name=MODEL_NAME):
    """Load the sequence-classification model and its tokenizer, once.

    Streamlit re-executes the whole script on every interaction. Without
    ``st.cache_resource`` the weights were re-read from disk/network on
    every click; now they are loaded once per ``model_name`` and reused.

    NOTE(review): if this checkpoint ships without a fine-tuned
    classification head, transformers initialises one randomly and the
    predicted classes are meaningless. Confirm which checkpoint is intended.

    Args:
        model_name: Hugging Face model id or local path. Defaults to
            ``MODEL_NAME``, the id the app originally hard-coded.

    Returns:
        ``(model, tokenizer)`` tuple.
    """
    model = BertForSequenceClassification.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)
    # from_pretrained already returns eval mode; kept explicit because
    # predict() assumes dropout is disabled.
    model.eval()
    return model, tokenizer
def predict(text, model, tokenizer):
    """Classify ``text`` and return the winning class with its probabilities.

    Args:
        text: Input string to classify.
        model: Classification model whose call result exposes ``.logits``.
        tokenizer: Tokenizer callable; called with ``return_tensors="pt"``,
            padding, and truncation to 512 tokens.

    Returns:
        ``(predicted_class, probs)``. ``predicted_class`` is the argmax index
        of the first row as a Python int, and ``probs`` is that first row
        after softmax over the last dimension.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        logits = model(**encoded).logits
    probs = logits.softmax(dim=-1)
    best = probs.argmax(dim=1).item()
    return best, probs[0]
def get_bot_response(predicted_class, confidence):
    """Map a predicted class index to a canned reply.

    Args:
        predicted_class: Class index produced by ``predict``.
        confidence: Probabilities passed by the caller; currently unused.

    Returns:
        The reply registered for ``predicted_class``, or a request to
        rephrase when the class has no reply.
    """
    # Customize these replies to match your model's label set.
    canned_replies = {
        0: "I understand this is about [Class 0]. Let me help you with that.",
        1: "This seems to be related to [Class 1]. Here's what I can tell you.",
        # Add more responses for your classes
    }
    if predicted_class in canned_replies:
        return canned_replies[predicted_class]
    return "I'm not quite sure about that. Could you please rephrase?"
def init_session_state():
    """Seed the session-state keys this app relies on.

    Sets ``messages`` to an empty list and ``thinking`` to ``False``, but
    only for keys not already present; existing values are left untouched.
    """
    defaults = {"messages": [], "thinking": False}
    for key, initial in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = initial
def display_chat_history():
    """Render every stored message as a styled chat bubble.

    Messages whose ``role`` is ``'user'`` get the ``user-message`` CSS class;
    every other role gets ``bot-message``.

    Content is HTML-escaped before interpolation. The markup is rendered
    with ``unsafe_allow_html=True``, so without escaping any user text
    containing ``<``/``>``/``&`` was parsed as HTML. That allowed markup and
    script injection and garbled ordinary messages.
    """
    for message in st.session_state.messages:
        css_class = "user-message" if message["role"] == "user" else "bot-message"
        safe_content = html.escape(str(message["content"]))
        st.markdown(
            f'<div class="{css_class}">{safe_content}</div>',
            unsafe_allow_html=True,
        )
def main():
    """Build the chatbot page; Streamlit re-executes this on every interaction.

    Each run configures the page, injects the CSS, seeds session state,
    fetches the model and draws the header, history area and input form.
    When the form is submitted with text:

    1. The user message is stored.
    2. History plus the typing indicator are drawn while the model runs.
    3. The reply is stored.
    4. The script is rerun so the page is redrawn from session state.
    """
    st.set_page_config(page_title="AI Chatbot", page_icon="🤖", layout="wide")
    local_css()
    init_session_state()

    # Load model
    model, tokenizer = load_model()

    # Chat interface
    st.title("AI Chatbot 🤖")
    st.markdown("Welcome! I'm here to help answer your questions in Vietnamese.")

    thinking_html = """
<div class="thinking-animation">
<div class="dot"></div>
<div class="dot"></div>
<div class="dot"></div>
</div>
"""

    # Created before the form so the history is drawn above the input row.
    chat_container = st.container()

    # A form (instead of a bare text_input + button) lets Enter submit.
    # clear_on_submit also actually empties the box after sending; a plain
    # rerun keeps a text_input's value, so the old "clear input" never worked.
    with st.form("chat_form", clear_on_submit=True):
        col1, col2 = st.columns([6, 1])
        with col1:
            user_input = st.text_input(
                "Type your message...",
                key="user_input",
                label_visibility="hidden",
            )
        with col2:
            send_button = st.form_submit_button("Send")

    if user_input and send_button:
        st.session_state.messages.append({"role": "user", "content": user_input})
        st.session_state.thinking = True
        # Draw the indicator now. The flag is cleared again before this run
        # ends, so the check at the bottom alone would never display it.
        with chat_container:
            display_chat_history()
            st.markdown(thinking_html, unsafe_allow_html=True)
        try:
            predicted_class, probabilities = predict(user_input, model, tokenizer)
            bot_response = get_bot_response(predicted_class, probabilities)
            time.sleep(1)  # Simulate processing time
        finally:
            # Never leave the indicator stuck on if inference raises.
            st.session_state.thinking = False
        st.session_state.messages.append({"role": "assistant", "content": bot_response})
        # Redraw the page from session state (drops the indicator drawn above).
        st.rerun()

    # Display chat history
    with chat_container:
        display_chat_history()
        if st.session_state.thinking:
            st.markdown(thinking_html, unsafe_allow_html=True)


if __name__ == "__main__":
    main()