JustKiddo commited on
Commit
a4c3bcc
β€’
1 Parent(s): 025e654

New BERT integration with new Chat UI

Browse files
Files changed (1) hide show
  1. app.py +174 -0
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import BertForSequenceClassification, BertTokenizer
3
+ import torch
4
+ import time
5
+ import streamlit.components.v1 as components
6
+
7
+ # Custom CSS for chat interface
8
+ def local_css():
9
+ st.markdown("""
10
+ <style>
11
+ .chat-container {
12
+ padding: 10px;
13
+ border-radius: 5px;
14
+ margin-bottom: 10px;
15
+ display: flex;
16
+ flex-direction: column;
17
+ }
18
+
19
+ .user-message {
20
+ background-color: #e3f2fd;
21
+ padding: 10px;
22
+ border-radius: 15px;
23
+ margin: 5px;
24
+ margin-left: 20%;
25
+ margin-right: 5px;
26
+ align-self: flex-end;
27
+ max-width: 70%;
28
+ }
29
+
30
+ .bot-message {
31
+ background-color: #f5f5f5;
32
+ padding: 10px;
33
+ border-radius: 15px;
34
+ margin: 5px;
35
+ margin-right: 20%;
36
+ margin-left: 5px;
37
+ align-self: flex-start;
38
+ max-width: 70%;
39
+ }
40
+
41
+ .chat-input {
42
+ position: fixed;
43
+ bottom: 0;
44
+ width: 100%;
45
+ padding: 20px;
46
+ background-color: white;
47
+ }
48
+
49
+ .thinking-animation {
50
+ display: flex;
51
+ align-items: center;
52
+ margin-left: 10px;
53
+ }
54
+
55
+ .dot {
56
+ width: 8px;
57
+ height: 8px;
58
+ margin: 0 3px;
59
+ background: #888;
60
+ border-radius: 50%;
61
+ animation: bounce 0.8s infinite;
62
+ }
63
+
64
+ .dot:nth-child(2) { animation-delay: 0.2s; }
65
+ .dot:nth-child(3) { animation-delay: 0.4s; }
66
+
67
+ @keyframes bounce {
68
+ 0%, 100% { transform: translateY(0); }
69
+ 50% { transform: translateY(-5px); }
70
+ }
71
+ </style>
72
+ """, unsafe_allow_html=True)
73
+
74
+ # Load model and tokenizer
75
+ @st.cache_resource
76
+ def load_model():
77
+ model = BertForSequenceClassification.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
78
+ tokenizer = BertTokenizer.from_pretrained("trituenhantaoio/bert-base-vietnamese-uncased")
79
+ return model, tokenizer
80
+
81
+ def predict(text, model, tokenizer):
82
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
83
+
84
+ with torch.no_grad():
85
+ outputs = model(**inputs)
86
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
87
+ predicted_class = torch.argmax(predictions, dim=1).item()
88
+
89
+ return predicted_class, predictions[0]
90
+
91
+ def get_bot_response(predicted_class, confidence):
92
+ # Customize these responses based on your model's classes
93
+ responses = {
94
+ 0: "I understand this is about [Class 0]. Let me help you with that.",
95
+ 1: "This seems to be related to [Class 1]. Here's what I can tell you.",
96
+ # Add more responses for your classes
97
+ }
98
+
99
+ default_response = "I'm not quite sure about that. Could you please rephrase?"
100
+ return responses.get(predicted_class, default_response)
101
+
102
+ def init_session_state():
103
+ if 'messages' not in st.session_state:
104
+ st.session_state.messages = []
105
+ if 'thinking' not in st.session_state:
106
+ st.session_state.thinking = False
107
+
108
+ def display_chat_history():
109
+ for message in st.session_state.messages:
110
+ if message['role'] == 'user':
111
+ st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True)
112
+ else:
113
+ st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
114
+
115
+ def main():
116
+ st.set_page_config(page_title="AI Chatbot", page_icon="πŸ€–", layout="wide")
117
+ local_css()
118
+ init_session_state()
119
+
120
+ # Load model
121
+ model, tokenizer = load_model()
122
+
123
+ # Chat interface
124
+ st.title("AI Chatbot πŸ€–")
125
+ st.markdown("Welcome! I'm here to help answer your questions in Vietnamese.")
126
+
127
+ # Chat history container
128
+ chat_container = st.container()
129
+
130
+ # Input container at the bottom
131
+ with st.container():
132
+ col1, col2 = st.columns([6, 1])
133
+ with col1:
134
+ user_input = st.text_input("Type your message...", key="user_input", label_visibility="hidden")
135
+ with col2:
136
+ send_button = st.button("Send")
137
+
138
+ if user_input and send_button:
139
+ # Add user message to chat
140
+ st.session_state.messages.append({"role": "user", "content": user_input})
141
+
142
+ # Show thinking animation
143
+ st.session_state.thinking = True
144
+
145
+ # Get model prediction
146
+ predicted_class, probabilities = predict(user_input, model, tokenizer)
147
+
148
+ # Get bot response
149
+ bot_response = get_bot_response(predicted_class, probabilities)
150
+
151
+ # Add bot response to chat
152
+ time.sleep(1) # Simulate processing time
153
+ st.session_state.messages.append({"role": "assistant", "content": bot_response})
154
+ st.session_state.thinking = False
155
+
156
+ # Clear input
157
+ st.rerun()
158
+
159
+ # Display chat history
160
+ with chat_container:
161
+ display_chat_history()
162
+
163
+ # Show thinking animation
164
+ if st.session_state.thinking:
165
+ st.markdown("""
166
+ <div class="thinking-animation">
167
+ <div class="dot"></div>
168
+ <div class="dot"></div>
169
+ <div class="dot"></div>
170
+ </div>
171
+ """, unsafe_allow_html=True)
172
+
173
+ if __name__ == "__main__":
174
+ main()