import html
import time

import streamlit as st
import streamlit.components.v1 as components  # noqa: F401  (currently unused)
import torch
from transformers import BertForSequenceClassification, BertTokenizer

MODEL_NAME = "trituenhantaoio/bert-base-vietnamese-uncased"

# Styles for the chat bubbles emitted by display_chat_history().
CHAT_CSS = """
<style>
.chat-message {
    padding: 0.75rem 1rem;
    border-radius: 0.75rem;
    margin-bottom: 0.5rem;
    max-width: 80%;
    color: #1f1f1f;
    word-wrap: break-word;
}
.user-message {
    background-color: #dcf8c6;
    margin-left: auto;
}
.bot-message {
    background-color: #f1f0f0;
    margin-right: auto;
}
</style>
"""


# Custom CSS for chat interface
def local_css():
    """Inject the stylesheet used by the chat bubbles."""
    st.markdown(CHAT_CSS, unsafe_allow_html=True)


# Load model and tokenizer
@st.cache_resource
def load_model():
    """Load the BERT classifier and its tokenizer once and cache them.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE(review): MODEL_NAME looks like a base pre-trained checkpoint. If it
    # has no fine-tuned classification head, transformers initialises the head
    # randomly and the predicted classes are arbitrary. Confirm, or point this
    # at a fine-tuned checkpoint whose labels match get_bot_response().
    model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()  # inference only: make sure dropout is disabled
    tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer


def predict(text, model, tokenizer):
    """Classify a single piece of text.

    Args:
        text: The user's message.
        model: Sequence-classification model returned by ``load_model``.
        tokenizer: Matching tokenizer returned by ``load_model``.

    Returns:
        ``(predicted_class, probabilities)``, where ``predicted_class`` is the
        argmax class index (int) and ``probabilities`` is the softmax
        distribution over classes for this one input.
    """
    inputs = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(predictions, dim=1).item()
    return predicted_class, predictions[0]


def get_bot_response(predicted_class, confidence, min_confidence=0.0):
    """Map a predicted class to a canned reply.

    Args:
        predicted_class: Class index returned by ``predict``.
        confidence: Per-class probability vector returned by ``predict``.
        min_confidence: If > 0, fall back to the default reply when the
            probability of ``predicted_class`` is below this value. The
            default of 0.0 disables the check (original behaviour).

    Returns:
        The reply string for the class, or a generic "please rephrase" reply
        for unknown or low-confidence classes.
    """
    # Customize these responses based on your model's classes
    responses = {
        0: "I understand this is about [Class 0]. Let me help you with that.",
        1: "This seems to be related to [Class 1]. Here's what I can tell you.",
        # Add more responses for your classes
    }
    default_response = "I'm not quite sure about that. Could you please rephrase?"
    if min_confidence > 0.0 and float(confidence[predicted_class]) < min_confidence:
        return default_response
    return responses.get(predicted_class, default_response)


def init_session_state():
    """Create the per-session keys this app relies on, if missing."""
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "thinking" not in st.session_state:
        st.session_state.thinking = False


def display_chat_history():
    """Render every stored message as a styled chat bubble.

    Message text is HTML-escaped before being embedded. The markup is rendered
    with ``unsafe_allow_html=True``, and user messages are untrusted input, so
    escaping is required to prevent HTML/script injection.
    """
    for message in st.session_state.messages:
        css_class = "user-message" if message["role"] == "user" else "bot-message"
        safe_content = html.escape(message["content"])
        st.markdown(
            f'<div class="chat-message {css_class}">{safe_content}</div>',
            unsafe_allow_html=True,
        )


def main():
    """Build the page: header, chat history, and the message input form."""
    st.set_page_config(page_title="AI Chatbot", page_icon="🤖", layout="wide")
    local_css()
    init_session_state()

    # Load model
    model, tokenizer = load_model()

    # Chat interface
    st.title("AI Chatbot 🤖")
    st.markdown("Welcome! I'm here to help answer your questions in Vietnamese.")

    # Created before the form so the history renders above the input box,
    # even though it is only filled in at the end of this run.
    chat_container = st.container()

    # clear_on_submit empties the text box after sending. Calling st.rerun()
    # does not reset a keyed text_input, so the old approach left the text in place.
    with st.form("chat_form", clear_on_submit=True):
        col1, col2 = st.columns([6, 1])
        with col1:
            user_input = st.text_input(
                "Type your message...", key="user_input", label_visibility="hidden"
            )
        with col2:
            send_button = st.form_submit_button("Send")

    if send_button and user_input.strip():
        st.session_state.messages.append({"role": "user", "content": user_input})

        # The spinner is visible while the reply is computed. The flag is
        # always reset, even if prediction raises, so it cannot get stuck.
        st.session_state.thinking = True
        try:
            with st.spinner("Thinking..."):
                predicted_class, probabilities = predict(user_input, model, tokenizer)
                bot_response = get_bot_response(predicted_class, probabilities)
                time.sleep(1)  # Simulate processing time
        finally:
            st.session_state.thinking = False

        st.session_state.messages.append(
            {"role": "assistant", "content": bot_response}
        )

    # History is drawn after any new messages were appended, so this run
    # already shows them; no rerun is needed.
    with chat_container:
        display_chat_history()


if __name__ == "__main__":
    main()