JustKiddo committed
Commit 47ed12b
1 Parent(s): 1934f9e

Update app.py

Files changed (1):
  1. app.py +162 -89
app.py CHANGED
@@ -1,109 +1,182 @@
  import streamlit as st
- import os
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
- from huggingface_hub import login
  import torch

- class VietnameseChatbot:
-     def __init__(self, model_name="tamgrnguyen/Gemma-2-2b-it-Vietnamese-Aesthetic"):
-         """
-         Initialize the Vietnamese chatbot with a pre-trained model
-         """
-         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-         self.model = AutoModelForCausalLM.from_pretrained(model_name)
-
-         # Use GPU if available
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-         self.model.to(self.device)
-
-     def generate_response(self, conversation_history, max_length=200):
-         """
-         Generate a response based on conversation history
-         """
-         # Combine conversation history into a single prompt
-         prompt = "\n".join(conversation_history)
-
-         # Tokenize input
-         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-
-         # Generate response
-         outputs = self.model.generate(
-             inputs.input_ids,
-             max_length=max_length,
-             num_return_sequences=1,
-             no_repeat_ngram_size=2,
-             temperature=0.7,
-             top_k=50,
-             top_p=0.95
-         )
-
-         # Decode the generated response
-         response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         # Extract only the new part of the response
-         response = response[len(prompt):].strip()
-
-         return response
-
- def main():
-     st.set_page_config(page_title="IOGPT", layout="wide")
-
-     st.title("Chat với IOGPT")
-     st.markdown("""
-     ### Trò chuyện với IOGPT
-     """)
-
-     # Initialize chatbot
-     if 'chatbot' not in st.session_state:
-         try:
-             st.session_state.chatbot = VietnameseChatbot()
-         except Exception as e:
-             st.error(f"Lỗi khởi tạo mô hình: {e}")
-             return
-
-     # Initialize conversation history
-     if 'conversation' not in st.session_state:
-         st.session_state.conversation = []

      # Chat interface
-     with st.form(key='chat_form'):
-         user_input = st.text_input("Nhập tin nhắn của bạn:", placeholder="Viết tin nhắn...")
-         send_button = st.form_submit_button("Gửi")
-
-     # Process user input
-     if send_button and user_input:
-         # Add user message to conversation
-         st.session_state.conversation.append(f"User: {user_input}")
-
-         try:
-             # Generate AI response
-             with st.spinner('Đang suy nghĩ...'):
-                 ai_response = st.session_state.chatbot.generate_response(
-                     st.session_state.conversation
-                 )
-
-             # Add AI response to conversation
-             st.session_state.conversation.append(f"AI: {ai_response}")
-
-         except Exception as e:
-             st.error(f"Lỗi trong quá trình trò chuyện: {e}")
-
-     # Display conversation history
-     st.subheader("Lịch sử trò chuyện")
-     for msg in st.session_state.conversation:
-         if msg.startswith("User:"):
-             st.markdown(f"**{msg}**")
-         else:
-             st.markdown(f"*{msg}*")
-
-     # Model and usage information
-     st.sidebar.header("Thông tin hình")
-     st.sidebar.info("""
-     ### Mô hình AI Tiếng Việt
-     - Được huấn luyện trên dữ liệu tiếng Việt
-     - Hỗ trợ trò chuyện đa dạng
-     - Lưu ý: Chất lượng trả lời phụ thuộc vào mô hình
-     """)

  if __name__ == "__main__":
      main()
 
  import streamlit as st
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  import torch
+ import time

+ # Custom CSS for the chat interface
+ def local_css():
+     st.markdown("""
+     <style>
+     .chat-container {
+         padding: 10px;
+         border-radius: 5px;
+         margin-bottom: 10px;
+         display: flex;
+         flex-direction: column;
+     }
+
+     .user-message {
+         background-color: #e3f2fd;
+         padding: 10px;
+         border-radius: 15px;
+         margin: 5px;
+         margin-left: 20%;
+         margin-right: 5px;
+         align-self: flex-end;
+         max-width: 70%;
+     }
+
+     .bot-message {
+         background-color: #f5f5f5;
+         padding: 10px;
+         border-radius: 15px;
+         margin: 5px;
+         margin-right: 20%;
+         margin-left: 5px;
+         align-self: flex-start;
+         max-width: 70%;
+     }
+
+     .thinking-animation {
+         display: flex;
+         align-items: center;
+         margin-left: 10px;
+     }
+
+     .dot {
+         width: 8px;
+         height: 8px;
+         margin: 0 3px;
+         background: #888;
+         border-radius: 50%;
+         animation: bounce 0.8s infinite;
+     }
+
+     .dot:nth-child(2) { animation-delay: 0.2s; }
+     .dot:nth-child(3) { animation-delay: 0.4s; }
+
+     @keyframes bounce {
+         0%, 100% { transform: translateY(0); }
+         50% { transform: translateY(-5px); }
+     }
+     </style>
+     """, unsafe_allow_html=True)
+ # Load model and tokenizer
+ @st.cache_resource
+ def load_model():
+     # Vietnamese-finetuned Gemma 2 2B instruct model
+     model_name = "tamgrnguyen/Gemma-2-2b-it-Vietnamese-Aesthetic"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     model = AutoModelForCausalLM.from_pretrained(model_name)
+     return model, tokenizer
+
+ def generate_response(prompt, model, tokenizer, max_length=100):
+     # Prepare input
+     inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
+
+     # Generate response
+     with torch.no_grad():
+         outputs = model.generate(
+             inputs.input_ids,
+             max_length=max_length,
+             num_return_sequences=1,
+             temperature=0.7,
+             top_k=50,
+             top_p=0.95,
+             do_sample=True,
+             pad_token_id=tokenizer.eos_token_id,
+             attention_mask=inputs.attention_mask
+         )
+
+     # Decode response
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     # Remove the input prompt from the response
+     response = response[len(prompt):].strip()
+     return response
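+
+ # NOTE: max_length counts prompt tokens plus generated tokens, so a long
+ # conversation history can leave little or no room for the reply; passing
+ # max_new_tokens instead would bound only the newly generated text.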
+
+ def init_session_state():
+     if 'messages' not in st.session_state:
+         st.session_state.messages = []
+     if 'thinking' not in st.session_state:
+         st.session_state.thinking = False
+
+ def display_chat_history():
+     for message in st.session_state.messages:
+         if message['role'] == 'user':
+             st.markdown(f'<div class="user-message">{message["content"]}</div>', unsafe_allow_html=True)
+         else:
+             st.markdown(f'<div class="bot-message">{message["content"]}</div>', unsafe_allow_html=True)
+
+ def main():
+     st.set_page_config(
+         page_title="AI Chatbot Tiếng Việt",
+         page_icon="🤖",
+         layout="wide"
+     )
+
+     local_css()
+     init_session_state()
+
+     # Load model
+     model, tokenizer = load_model()
+
      # Chat interface
+     st.title("AI Chatbot Tiếng Việt 🤖")
+     st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")
+
+     # Chat history container
+     chat_container = st.container()
+
+     # Input container
+     with st.container():
+         col1, col2 = st.columns([6, 1])
+         with col1:
+             user_input = st.text_input(
+                 "Nhập tin nhắn của bạn...",
+                 key="user_input",
+                 label_visibility="hidden"
+             )
+         with col2:
+             send_button = st.button("Gửi")
+
+     if user_input and send_button:
+         # Add user message
+         st.session_state.messages.append({"role": "user", "content": user_input})
+
+         # Show thinking animation
+         st.session_state.thinking = True
+
+         # Prepare conversation history
+         conversation_history = "\n".join([
+             f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
+             for msg in st.session_state.messages[-3:]  # Last 3 messages for context
+         ])
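+         # NOTE: this plain "User:/Assistant:" transcript is a simplification;
+         # Gemma instruct checkpoints are usually prompted through
+         # tokenizer.apply_chat_template, which may give better replies.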
+
+         # Generate response
+         prompt = f"{conversation_history}\nAssistant:"
+         bot_response = generate_response(prompt, model, tokenizer)
+
+         # Add bot response
+         time.sleep(0.5)  # Brief delay for natural feeling
+         st.session_state.messages.append({"role": "assistant", "content": bot_response})
+         st.session_state.thinking = False
+
+         # Rerun so the updated history is rendered (the keyed text_input
+         # keeps its value across the rerun)
+         st.rerun()
+
+     # Display chat history
+     with chat_container:
+         display_chat_history()
+
+         if st.session_state.thinking:
+             st.markdown("""
+             <div class="thinking-animation">
+                 <div class="dot"></div>
+                 <div class="dot"></div>
+                 <div class="dot"></div>
+             </div>
+             """, unsafe_allow_html=True)

  if __name__ == "__main__":
      main()
  main()