JustKiddo committed on
Commit
985ad3e
1 Parent(s): 76bb75b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -178
app.py CHANGED
@@ -1,202 +1,111 @@
1
  import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
- import time
5
 
6
- # Custom CSS for the chat interface
7
- def local_css():
8
- st.markdown("""
9
- <style>
10
- .chat-container {
11
- padding: 10px;
12
- border-radius: 5px;
13
- margin-bottom: 10px;
14
- display: flex;
15
- flex-direction: column;
16
- }
17
-
18
- .user-message {
19
- background-color: #e3f2fd;
20
- padding: 10px;
21
- border-radius: 15px;
22
- margin: 5px;
23
- margin-left: 20%;
24
- margin-right: 5px;
25
- align-self: flex-end;
26
- max-width: 70%;
27
- }
28
-
29
- .bot-message {
30
- background-color: #f5f5f5;
31
- padding: 10px;
32
- border-radius: 15px;
33
- margin: 5px;
34
- margin-right: 20%;
35
- margin-left: 5px;
36
- align-self: flex-start;
37
- max-width: 70%;
38
- }
39
-
40
- .thinking-animation {
41
- display: flex;
42
- align-items: center;
43
- margin-left: 10px;
44
- }
45
 
46
- .dot {
47
- width: 8px;
48
- height: 8px;
49
- margin: 0 3px;
50
- background: #888;
51
- border-radius: 50%;
52
- animation: bounce 0.8s infinite;
53
- }
54
-
55
- .dot:nth-child(2) { animation-delay: 0.2s; }
56
- .dot:nth-child(3) { animation-delay: 0.4s; }
57
-
58
- @keyframes bounce {
59
- 0%, 100% { transform: translateY(0); }
60
- 50% { transform: translateY(-5px); }
61
- }
62
- </style>
63
- """, unsafe_allow_html=True)
64
 
65
- # Load model and tokenizer
66
- @st.cache_resource
67
- def load_model():
68
- # Using VietAI's Vietnamese GPT model
69
- model_name = "vietai/gpt-neo-1.3B-vietnamese-news"
70
- tokenizer = AutoTokenizer.from_pretrained(model_name)
71
- model = AutoModelForCausalLM.from_pretrained(model_name)
72
-
73
- # Fix the padding token issue
74
- tokenizer.pad_token = tokenizer.eos_token
75
- model.config.pad_token_id = model.config.eos_token_id
76
-
77
- return model, tokenizer
78
-
79
- def generate_response(prompt, model, tokenizer, max_length=100):
80
- try:
81
- # Prepare input
82
- inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
83
 
84
  # Generate response
85
- with torch.no_grad():
86
- outputs = model.generate(
87
- inputs.input_ids,
88
- max_length=max_length,
89
- num_return_sequences=1,
90
- temperature=0.7,
91
- top_k=50,
92
- top_p=0.95,
93
- do_sample=True,
94
- pad_token_id=tokenizer.pad_token_id,
95
- attention_mask=inputs.attention_mask
96
- )
97
 
98
- # Decode response
99
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
100
- # Remove the input prompt from the response
 
101
  response = response[len(prompt):].strip()
102
 
103
- # If response is empty, return a default message
104
- if not response:
105
- return "Xin lỗi, tôi không thể tạo câu trả lời. Bạn có thể hỏi lại không?"
106
-
107
  return response
108
-
109
- except Exception as e:
110
- st.error(f"Error generating response: {str(e)}")
111
- return "Xin lỗi, đã có lỗi xảy ra. Vui lòng thử lại."
112
-
113
- def init_session_state():
114
- if 'messages' not in st.session_state:
115
- st.session_state.messages = []
116
- if 'thinking' not in st.session_state:
117
- st.session_state.thinking = False
118
-
119
- def display_chat_history():
120
- for message in st.session_state.messages:
121
- if message['role'] == 'user':
122
- st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True)
123
- else:
124
- st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
125
 
126
  def main():
127
  st.set_page_config(
128
- page_title="AI Chatbot Tiếng Việt",
129
- page_icon="🤖",
130
  layout="wide"
131
  )
132
-
133
- local_css()
134
- init_session_state()
135
-
136
- try:
137
- # Load model
138
- model, tokenizer = load_model()
139
-
140
- # Chat interface
141
- st.title("AI Chatbot Tiếng Việt 🤖")
142
- st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")
143
-
144
- # Chat history container
145
- chat_container = st.container()
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- # Input container
148
- with st.container():
149
- col1, col2 = st.columns([6, 1])
150
- with col1:
151
- user_input = st.text_input(
152
- "Nhập tin nhắn của bạn...",
153
- key="user_input",
154
- label_visibility="hidden"
155
  )
156
- with col2:
157
- send_button = st.button("Gửi")
158
-
159
- if user_input and send_button:
160
- # Add user message
161
- st.session_state.messages.append({"role": "user", "content": user_input})
162
-
163
- # Show thinking animation
164
- st.session_state.thinking = True
165
 
166
- # Prepare conversation history
167
- conversation_history = "\n".join([
168
- f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
169
- for msg in st.session_state.messages[-3:] # Last 3 messages for context
170
- ])
171
-
172
- # Generate response
173
- prompt = f"{conversation_history}\nAssistant:"
174
- bot_response = generate_response(prompt, model, tokenizer)
175
-
176
- # Add bot response
177
- time.sleep(0.5) # Brief delay for natural feeling
178
- st.session_state.messages.append({"role": "assistant", "content": bot_response})
179
- st.session_state.thinking = False
180
-
181
- # Clear input and rerun
182
- st.rerun()
183
 
184
- # Display chat history
185
- with chat_container:
186
- display_chat_history()
187
-
188
- if st.session_state.thinking:
189
- st.markdown("""
190
- <div class="thinking-animation">
191
- <div class="dot"></div>
192
- <div class="dot"></div>
193
- <div class="dot"></div>
194
- </div>
195
- """, unsafe_allow_html=True)
196
-
197
- except Exception as e:
198
- st.error(f"An error occurred: {str(e)}")
199
- st.info("Please refresh the page to try again.")
 
 
 
200
 
201
  if __name__ == "__main__":
202
  main()
 
1
  import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
 
4
 
5
class VietnameseChatbot:
    def __init__(self, model_name="vietai/gpt-neo-1.3B-vietnamese-news"):
        """
        Initialize the Vietnamese chatbot with a pre-trained causal LM.

        Args:
            model_name: Hugging Face Hub repo id. Defaults to VietAI's
                Vietnamese GPT-Neo model. (The previous default used the
                "vinai/" namespace, which does not host this model —
                it is published under "vietai/", as the earlier revision
                of this file used.)
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

        # GPT-Neo ships without a dedicated padding token; reuse EOS so
        # tokenization and generation do not warn or fail on padding.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model.config.pad_token_id = self.model.config.eos_token_id

        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()  # inference only — disable dropout

    def generate_response(self, conversation_history, max_length=200):
        """
        Generate a response based on conversation history.

        Args:
            conversation_history: list of "Speaker: text" strings; joined
                with newlines to form the prompt.
            max_length: budget of NEW tokens to generate (the prompt does
                not count against it, so long histories still get a reply).

        Returns:
            The newly generated text with the prompt stripped, or a
            Vietnamese fallback message when the model emits nothing.
        """
        # Combine conversation history into a single prompt
        prompt = "\n".join(conversation_history)

        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Generate response. do_sample=True is required for temperature /
        # top_k / top_p to take effect; no_grad avoids building an
        # autograd graph during inference.
        with torch.no_grad():
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_length,
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Decode the generated response
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract only the new part of the response
        response = response[len(prompt):].strip()

        # Never surface an empty reply to the UI
        if not response:
            return "Xin lỗi, tôi không thể tạo câu trả lời. Bạn có thể hỏi lại không?"

        return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
def main():
    """Streamlit entry point: render the chat UI and drive the chatbot."""
    st.set_page_config(
        page_title="IOGPT",
        page_icon="🇻🇳",
        layout="wide"
    )

    st.title("🇻🇳 Chat với IOGPT")
    st.markdown("""
    ### Trò chuyện với IOGPT
    """)

    # Initialize chatbot once per session — model download/load is
    # expensive, so keep the instance in session_state.
    if 'chatbot' not in st.session_state:
        try:
            st.session_state.chatbot = VietnameseChatbot()
        except Exception as e:
            st.error(f"Lỗi khởi tạo mô hình: {e}")
            return

    # Initialize conversation history
    if 'conversation' not in st.session_state:
        st.session_state.conversation = []

    # Chat interface — a form so Enter or the button submits atomically
    with st.form(key='chat_form'):
        user_input = st.text_input("Nhập tin nhắn của bạn:", placeholder="Viết tin nhắn...")
        send_button = st.form_submit_button("Gửi")

    # Process user input
    if send_button and user_input:
        # Add user message to conversation
        st.session_state.conversation.append(f"User: {user_input}")

        try:
            # Generate AI response
            with st.spinner('Đang suy nghĩ...'):
                ai_response = st.session_state.chatbot.generate_response(
                    st.session_state.conversation
                )

            # Add AI response to conversation
            st.session_state.conversation.append(f"AI: {ai_response}")

        except Exception as e:
            st.error(f"Lỗi trong quá trình trò chuyện: {e}")

    # Display conversation history
    st.subheader("Lịch sử trò chuyện")
    for msg in st.session_state.conversation:
        if msg.startswith("User:"):
            st.markdown(f"**{msg}**")
        else:
            st.markdown(f"*{msg}*")

    # Model and usage information
    st.sidebar.header("Thông tin mô hình")
    st.sidebar.info("""
    ### Mô hình AI Tiếng Việt
    - Được huấn luyện trên dữ liệu tiếng Việt
    - Hỗ trợ trò chuyện đa dạng
    - Lưu ý: Chất lượng trả lời phụ thuộc vào mô hình
    """)
109
 
110
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()