"""Streamlit chat app: a simple Vietnamese-language chatbot.

Loads VietAI's GPT-Neo 1.3B model via Hugging Face Transformers and serves a
chat UI styled with custom CSS bubbles.
"""
import html
import time

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Custom CSS for the chat interface
def local_css():
    st.markdown("""
        <style>
        .chat-container {
            padding: 10px;
            border-radius: 5px;
            margin-bottom: 10px;
            display: flex;
            flex-direction: column;
        }
        
        .user-message {
            background-color: #e3f2fd;
            padding: 10px;
            border-radius: 15px;
            margin: 5px;
            margin-left: 20%;
            margin-right: 5px;
            align-self: flex-end;
            max-width: 70%;
        }
        
        .bot-message {
            background-color: #f5f5f5;
            padding: 10px;
            border-radius: 15px;
            margin: 5px;
            margin-right: 20%;
            margin-left: 5px;
            align-self: flex-start;
            max-width: 70%;
        }
        
        .thinking-animation {
            display: flex;
            align-items: center;
            margin-left: 10px;
        }
        
        .dot {
            width: 8px;
            height: 8px;
            margin: 0 3px;
            background: #888;
            border-radius: 50%;
            animation: bounce 0.8s infinite;
        }
        
        .dot:nth-child(2) { animation-delay: 0.2s; }
        .dot:nth-child(3) { animation-delay: 0.4s; }
        
        @keyframes bounce {
            0%, 100% { transform: translateY(0); }
            50% { transform: translateY(-5px); }
        }
        </style>
    """, unsafe_allow_html=True)

# Load model and tokenizer (cached so the model loads only once per process)
@st.cache_resource
def load_model():
    # VietAI's Vietnamese GPT-Neo model (1.3B parameters, trained on news text)
    model_name = "vietai/gpt-neo-1.3B-vietnamese-news"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    
    # GPT-Neo ships without a dedicated padding token, so reuse the EOS token
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id
    
    return model, tokenizer
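
# Optional (illustrative sketch only, not wired into the app as written): the
# 1.3B model is slow on CPU; on a machine with CUDA one could move it to the
# GPU after loading, e.g.
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)
#   # ...and in generate_response, move the tokenized inputs as well:
#   inputs = {k: v.to(device) for k, v in inputs.items()}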

def generate_response(prompt, model, tokenizer, max_new_tokens=100):
    try:
        # Tokenize the prompt
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        
        # Sample a continuation; max_new_tokens caps the reply length on its
        # own, whereas max_length would count the prompt tokens against the
        # budget and can swallow the reply for long conversation histories
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                attention_mask=inputs.attention_mask
            )
        
        # Keep only the newly generated tokens; slicing the token ids is more
        # reliable than stripping len(prompt) characters, since decoding does
        # not always reproduce the prompt string exactly
        new_tokens = outputs[0][inputs.input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        
        # Fall back to a default message if generation came back empty
        if not response:
            # "Sorry, I can't generate an answer. Could you ask again?"
            return "Xin lỗi, tôi không thể tạo câu trả lời. Bạn có thể hỏi lại không?"
            
        return response
        
    except Exception as e:
        st.error(f"Error generating response: {str(e)}")
        # "Sorry, an error occurred. Please try again."
        return "Xin lỗi, đã có lỗi xảy ra. Vui lòng thử lại."

def init_session_state():
    # Seed the chat history on first page load; the "thinking" indicator is
    # rendered through a placeholder in main(), so no flag is needed here
    if 'messages' not in st.session_state:
        st.session_state.messages = []

def display_chat_history():
    # Render each message as a styled bubble; escape the text so raw HTML in a
    # message cannot break the layout
    for message in st.session_state.messages:
        css_class = "user-message" if message['role'] == 'user' else "bot-message"
        st.markdown(f'<div class="{css_class}">{html.escape(message["content"])}</div>', unsafe_allow_html=True)

def main():
    st.set_page_config(
        page_title="AI Chatbot Tiếng Việt",
        page_icon="🤖",
        layout="wide"
    )
    
    local_css()
    init_session_state()
    
    try:
        # Load model (cached by @st.cache_resource)
        model, tokenizer = load_model()
        
        # Chat interface
        st.title("AI Chatbot Tiếng Việt 🤖")  # "Vietnamese AI Chatbot"
        # "Hello! I'm an AI assistant that can chat in Vietnamese. Ask me anything!"
        st.markdown("Xin chào! Tôi là trợ lý AI có thể trò chuyện bằng tiếng Việt. Hãy hỏi tôi bất cứ điều gì!")
        
        # Chat history container
        chat_container = st.container()
        
        # Input row: text box plus send button
        with st.container():
            col1, col2 = st.columns([6, 1])
            with col1:
                user_input = st.text_input(
                    "Nhập tin nhắn của bạn...",  # "Enter your message..."
                    key="user_input",
                    label_visibility="hidden"
                )
            with col2:
                send_button = st.button("Gửi")  # "Send"
        
        if user_input and send_button:
            # Add the user message to the history
            st.session_state.messages.append({"role": "user", "content": user_input})
            
            # Render the thinking dots in a placeholder while the model runs.
            # (Generation is synchronous, so a session-state flag toggled in
            # this block would already be reset by the time the page reruns,
            # and the animation would never appear.)
            thinking_placeholder = st.empty()
            thinking_placeholder.markdown("""
                <div class="thinking-animation">
                    <div class="dot"></div>
                    <div class="dot"></div>
                    <div class="dot"></div>
                </div>
            """, unsafe_allow_html=True)
            
            # Build a short prompt from the last three messages for context
            conversation_history = "\n".join([
                f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
                for msg in st.session_state.messages[-3:]
            ])
            
            # Generate the reply, then clear the thinking indicator
            prompt = f"{conversation_history}\nAssistant:"
            bot_response = generate_response(prompt, model, tokenizer)
            thinking_placeholder.empty()
            
            # Add the bot response to the history
            time.sleep(0.5)  # brief pause so the reply doesn't appear abruptly
            st.session_state.messages.append({"role": "assistant", "content": bot_response})
            
            # Rerun so the updated history renders (note: st.text_input keeps
            # its value across reruns when it has a key)
            st.rerun()
        
        # Display chat history
        with chat_container:
            display_chat_history()
                
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        st.info("Please refresh the page to try again.")

if __name__ == "__main__":
    main()
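
# To launch the app (assuming this file is saved as app.py):
#
#   streamlit run app.py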