File size: 5,497 Bytes
a4c3bcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import html
import time

import streamlit as st
import streamlit.components.v1 as components
import torch
from transformers import BertForSequenceClassification, BertTokenizer

# Custom CSS for chat interface
def local_css():
    """Inject the stylesheet used by the chat page.

    Defines the ``user-message`` / ``bot-message`` bubble classes used when
    the chat history is rendered, the ``thinking-animation`` container with
    its three bouncing ``dot`` elements, and two layout classes
    (``chat-container``, ``chat-input``) that no visible code in this file
    references. The markup is passed with ``unsafe_allow_html=True`` so
    Streamlit emits the raw ``<style>`` tag instead of escaping it.
    """
    stylesheet = """
        <style>
        .chat-container {
            padding: 10px;
            border-radius: 5px;
            margin-bottom: 10px;
            display: flex;
            flex-direction: column;
        }
        
        .user-message {
            background-color: #e3f2fd;
            padding: 10px;
            border-radius: 15px;
            margin: 5px;
            margin-left: 20%;
            margin-right: 5px;
            align-self: flex-end;
            max-width: 70%;
        }
        
        .bot-message {
            background-color: #f5f5f5;
            padding: 10px;
            border-radius: 15px;
            margin: 5px;
            margin-right: 20%;
            margin-left: 5px;
            align-self: flex-start;
            max-width: 70%;
        }
        
        .chat-input {
            position: fixed;
            bottom: 0;
            width: 100%;
            padding: 20px;
            background-color: white;
        }
        
        .thinking-animation {
            display: flex;
            align-items: center;
            margin-left: 10px;
        }
        
        .dot {
            width: 8px;
            height: 8px;
            margin: 0 3px;
            background: #888;
            border-radius: 50%;
            animation: bounce 0.8s infinite;
        }
        
        .dot:nth-child(2) { animation-delay: 0.2s; }
        .dot:nth-child(3) { animation-delay: 0.4s; }
        
        @keyframes bounce {
            0%, 100% { transform: translateY(0); }
            50% { transform: translateY(-5px); }
        }
        </style>
    """
    st.markdown(stylesheet, unsafe_allow_html=True)

# Load model and tokenizer
@st.cache_resource
def load_model(model_name="trituenhantaoio/bert-base-vietnamese-uncased"):
    """Load (once) the sequence-classification model and its tokenizer.

    ``st.cache_resource`` keeps the loaded objects across reruns, so the
    weights are not reloaded on every interaction. The cache is keyed on
    ``model_name``.

    Args:
        model_name: Hugging Face Hub id or local path of a BERT checkpoint.
            The default is the checkpoint this app was originally written
            for. Pass a fine-tuned checkpoint to change the classifier.

    Returns:
        ``(model, tokenizer)`` with the model in evaluation mode.

    NOTE(review): if ``model_name`` points at a plain pre-trained encoder
    rather than a classifier fine-tuned for this app's classes, the
    classification head is freshly initialised and predictions carry no
    meaning. Confirm which checkpoint is intended.
    """
    model = BertForSequenceClassification.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)
    # Inference only: disable dropout explicitly rather than relying on the loader.
    model.eval()
    return model, tokenizer

def predict(text, model, tokenizer):
    """Classify ``text`` and return the winning class and its probabilities.

    Args:
        text: Input handed directly to ``tokenizer``. Sequences are padded
            and truncated to at most 512 tokens.
        model: Classification model called as ``model(**inputs)``. Its output
            must expose ``.logits``. If it has a ``device`` attribute, the
            inputs are moved there first.
        tokenizer: Callable returning a mapping of tensors (``return_tensors="pt"``).

    Returns:
        ``(predicted_class, probabilities)`` for the first (normally the only)
        item in the batch. ``predicted_class`` is the argmax index as a Python
        ``int``. ``probabilities`` is the 1-D softmax tensor over all classes.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Keep the inputs on the same device as the weights, so this still works
    # if the model is moved to a GPU.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    # Take the argmax of the same row that is returned. Calling .item() on
    # the whole batch raised as soon as more than one text was passed.
    predicted_class = int(torch.argmax(probabilities).item())
    return predicted_class, probabilities

def get_bot_response(predicted_class, confidence):
    """Turn a predicted class index into a canned chatbot reply.

    Args:
        predicted_class: Class index, as produced by ``predict``.
        confidence: The probabilities ``main`` passes alongside the class.
            Currently unused; kept so a confidence threshold can be added
            without changing callers.

    Returns:
        The reply registered for ``predicted_class``, or a generic
        "please rephrase" message when the class has no entry.
    """
    fallback = "I'm not quite sure about that. Could you please rephrase?"
    # Placeholder wording: replace "[Class N]" with real label names and add
    # one entry per class the model can output.
    canned_replies = {
        0: "I understand this is about [Class 0]. Let me help you with that.",
        1: "This seems to be related to [Class 1]. Here's what I can tell you.",
    }
    if predicted_class in canned_replies:
        return canned_replies[predicted_class]
    return fallback

def init_session_state():
    """Seed ``st.session_state`` with the keys the app reads.

    Only missing keys are filled in, so values already stored by earlier
    runs (the chat transcript, the indicator flag) are left alone.

    Keys:
        messages: list of ``{"role": ..., "content": ...}`` dicts, appended in
            chronological order by ``main``.
        thinking: bool checked by ``main`` to decide whether to draw the
            bouncing-dot indicator.
    """
    state = st.session_state
    defaults = (
        ("messages", []),  # a new list on every call, never shared
        ("thinking", False),
    )
    for key, initial in defaults:
        if key not in state:
            state[key] = initial

def display_chat_history():
    """Render every stored chat message as a styled bubble.

    Messages whose ``role`` is ``"user"`` use the ``user-message`` CSS class.
    Every other role is drawn with ``bot-message``.

    The message text is HTML-escaped before being embedded. The bubble is
    rendered with ``unsafe_allow_html=True``, and user messages are arbitrary
    typed input. Without escaping, text such as ``<div>`` or
    ``<img src=x onerror=...>`` would be interpreted as markup instead of
    being shown literally.
    """
    for message in st.session_state.messages:
        css_class = "user-message" if message["role"] == "user" else "bot-message"
        safe_text = html.escape(message["content"])
        st.markdown(f'<div class="{css_class}">{safe_text}</div>', unsafe_allow_html=True)

def main():
    """Render the chatbot page and process at most one submitted message per run.

    When a message is submitted, the run proceeds in four steps:
    1. The message is appended to the history and drawn immediately,
       together with the bouncing-dot "thinking" indicator.
    2. The model classifies it.
    3. The canned reply is appended.
    4. The script reruns, so the page is redrawn without the indicator.
    """
    st.set_page_config(page_title="AI Chatbot", page_icon="🤖", layout="wide")
    local_css()
    init_session_state()

    # load_model is wrapped in st.cache_resource, so only the first run pays the load cost.
    model, tokenizer = load_model()

    st.title("AI Chatbot 🤖")
    st.markdown("Welcome! I'm here to help answer your questions in Vietnamese.")

    # Created before the input row, so the history is drawn above it.
    chat_container = st.container()

    # clear_on_submit empties the text box after each send. st.rerun() alone
    # keeps the widget's value, which left the old text in the box and let a
    # second click on Send post it again.
    with st.form("chat_form", clear_on_submit=True):
        col1, col2 = st.columns([6, 1])
        with col1:
            user_input = st.text_input("Type your message...", key="user_input", label_visibility="hidden")
        with col2:
            send_button = st.form_submit_button("Send")

    # Ignore empty and whitespace-only submissions.
    message = user_input.strip() if send_button and user_input else ""

    # The flag describes only the current run, so a run that was interrupted
    # earlier cannot leave the indicator stuck on.
    st.session_state.thinking = bool(message)
    if message:
        st.session_state.messages.append({"role": "user", "content": message})

    with chat_container:
        display_chat_history()
        # Drawn before the model runs, so it is visible while predict() works.
        # Previously the flag was set and cleared within one run before
        # anything was drawn, so the dots never appeared.
        if st.session_state.thinking:
            st.markdown("""
                <div class="thinking-animation">
                    <div class="dot"></div>
                    <div class="dot"></div>
                    <div class="dot"></div>
                </div>
            """, unsafe_allow_html=True)

    if message:
        predicted_class, probabilities = predict(message, model, tokenizer)
        bot_response = get_bot_response(predicted_class, probabilities)
        st.session_state.messages.append({"role": "assistant", "content": bot_response})
        st.session_state.thinking = False
        # Redraw the page with the reply and without the indicator.
        st.rerun()

# Script entry point. Presumably launched with `streamlit run <this file>`,
# which executes the module as __main__ — confirm for your deployment.
if __name__ == "__main__":
    main()