File size: 5,408 Bytes
a4c3bcc
47ed12b
a4c3bcc
47ed12b
a4c3bcc
47ed12b
 
 
 
 
 
 
 
 
 
 
a4c3bcc
47ed12b
 
 
 
 
 
 
 
 
 
985ad3e
47ed12b
 
 
 
 
 
 
 
 
 
76bb75b
47ed12b
 
 
 
 
76bb75b
47ed12b
 
 
 
 
 
 
 
985ad3e
47ed12b
 
76bb75b
47ed12b
 
 
 
 
 
a4c3bcc
47ed12b
 
 
 
 
 
 
 
985ad3e
47ed12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985ad3e
47ed12b
 
 
 
 
985ad3e
47ed12b
 
 
 
 
 
985ad3e
47ed12b
 
9e0b73b
47ed12b
 
 
 
 
 
 
 
 
 
985ad3e
9e0b73b
 
47ed12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c3bcc
47ed12b
 
a4c3bcc
47ed12b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c3bcc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import html
import time

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Custom CSS for the chat interface
def local_css():
    """Inject the page-wide stylesheet for chat bubbles and the typing indicator.

    Defines the ``.chat-container``, ``.user-message``, ``.bot-message``,
    ``.thinking-animation`` and ``.dot`` classes, plus the ``bounce`` keyframes
    that animate the dots. ``unsafe_allow_html=True`` is required because
    ``st.markdown`` escapes raw HTML by default.
    """
    stylesheet = """
        <style>
        .chat-container {
            padding: 10px;
            border-radius: 5px;
            margin-bottom: 10px;
            display: flex;
            flex-direction: column;
        }
        
        .user-message {
            background-color: #e3f2fd;
            padding: 10px;
            border-radius: 15px;
            margin: 5px;
            margin-left: 20%;
            margin-right: 5px;
            align-self: flex-end;
            max-width: 70%;
        }
        
        .bot-message {
            background-color: #f5f5f5;
            padding: 10px;
            border-radius: 15px;
            margin: 5px;
            margin-right: 20%;
            margin-left: 5px;
            align-self: flex-start;
            max-width: 70%;
        }
        
        .thinking-animation {
            display: flex;
            align-items: center;
            margin-left: 10px;
        }
        
        .dot {
            width: 8px;
            height: 8px;
            margin: 0 3px;
            background: #888;
            border-radius: 50%;
            animation: bounce 0.8s infinite;
        }
        
        .dot:nth-child(2) { animation-delay: 0.2s; }
        .dot:nth-child(3) { animation-delay: 0.4s; }
        
        @keyframes bounce {
            0%, 100% { transform: translateY(0); }
            50% { transform: translateY(-5px); }
        }
        </style>
    """
    st.markdown(stylesheet, unsafe_allow_html=True)

# Load model and tokenizer
@st.cache_resource
def load_model(model_name="tamgrnguyen/Gemma-2-2b-it-Vietnamese-Aesthetic"):
    """Load the causal language model and its tokenizer, cached by Streamlit.

    ``st.cache_resource`` keys the cache on the arguments. The model for a given
    ``model_name`` is therefore loaded once and then reused on every rerun,
    instead of being re-read from disk or the Hub each time.

    Args:
        model_name: Hugging Face Hub repository id or local path. The default,
            judging by its repo name, is a Vietnamese fine-tune of Gemma-2 2B
            instruct. It is not a VietAI GPT model.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer

def generate_response(prompt, model, tokenizer, max_length=100):
    """Sample a continuation of ``prompt`` and return only the newly generated text.

    Args:
        prompt: Plain-text prompt, typically a transcript of ``User:`` /
            ``Assistant:`` turns ending with ``"Assistant:"``.
        model: Causal language model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        max_length: Maximum number of *new* tokens to generate. The prompt does
            not count against this budget, so long conversations still get a
            reply.

    Returns:
        The generated reply, whitespace-stripped. Anything from a follow-on
        ``"\\nUser:"`` turn onward is removed.
    """
    # Prepare input
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    prompt_len = inputs.input_ids.shape[-1]

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            # max_new_tokens rather than max_length: max_length includes the
            # prompt, so a long history could consume the whole budget and
            # leave no room for a reply.
            max_new_tokens=max_length,
            num_return_sequences=1,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens generated after the prompt. Slicing the decoded text
    # by len(prompt) is fragile, because decode(encode(prompt)) is not guaranteed
    # to reproduce the prompt character for character.
    new_tokens = outputs[0][prompt_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # A sampled transcript often continues with an invented user turn; keep only
    # the assistant's part.
    response = response.split("\nUser:", 1)[0]
    return response.strip()

def init_session_state():
    """Seed per-session state the first time a browser session runs the script.

    Creates ``messages`` (the chat transcript, a list of ``{"role", "content"}``
    dicts) and ``thinking`` (whether the typing indicator is shown) when they
    are missing. Existing values are left alone, so reruns keep the
    conversation.
    """
    # Each default is built by calling a factory, so every session gets its own
    # fresh list: list() -> [], bool() -> False.
    for state_key, make_default in (("messages", list), ("thinking", bool)):
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()

def display_chat_history():
    """Render every message in ``st.session_state.messages`` as a chat bubble.

    Messages whose role is ``"user"`` get the ``user-message`` CSS class; all
    others get ``bot-message``. The bubble markup is rendered with
    ``unsafe_allow_html=True``, and its text comes from the user and from model
    output. For that reason the text is HTML-escaped before being embedded.
    """
    for message in st.session_state.messages:
        css_class = "user-message" if message["role"] == "user" else "bot-message"
        # Escape so markup typed by the user or produced by the model (e.g. "<b>",
        # "<img ...>") is shown literally instead of being rendered as HTML.
        # Join lines with <br> so multi-line replies stay inside the bubble:
        # in Markdown, a blank line would end the HTML block early.
        body = "<br>".join(
            html.escape(line) for line in message["content"].splitlines()
        )
        st.markdown(f'<div class="{css_class}">{body}</div>', unsafe_allow_html=True)

def main():
    """Build the IOGPT chat page and handle at most one submitted message per run.

    When the form is submitted with non-blank text, this function:

    1. appends the message to the transcript;
    2. draws the history and a typing indicator;
    3. generates a reply from the last three messages;
    4. appends the reply and reruns the script so the final state is shown.
    """
    st.set_page_config(
        page_title="IOGPT",
        page_icon="🤖",
        layout="wide"
    )

    local_css()
    init_session_state()

    # Load model (load_model is wrapped in st.cache_resource, so reruns reuse it)
    model, tokenizer = load_model()

    # Chat interface
    st.title("IOGPT 🤖")
    st.markdown("Xin chào! Tôi là trợ lý IOGPT. Hãy hỏi tôi bất cứ điều gì!")

    # Chat history container, created first so it renders above the input form
    chat_container = st.container()

    # Input form. clear_on_submit empties the text box after a message is sent;
    # a plain st.text_input cannot be reset from code once drawn in this run.
    with st.form("chat_form", clear_on_submit=True):
        col1, col2 = st.columns([6, 1])
        with col1:
            user_input = st.text_input(
                "Nhập tin nhắn của bạn...",
                key="user_input",
                label_visibility="hidden"
            )
        with col2:
            send_button = st.form_submit_button("Gửi")

    # Ignore empty or whitespace-only submissions
    submitted = bool(send_button and user_input and user_input.strip())

    if submitted:
        # Add user message and turn on the typing indicator for this run
        st.session_state.messages.append({"role": "user", "content": user_input})
        st.session_state.thinking = True

    # Display chat history and, while a reply is pending, the typing indicator.
    # This runs before generation so the indicator is on screen while the model
    # is working.
    with chat_container:
        display_chat_history()

        if st.session_state.thinking:
            st.markdown("""
                <div class="thinking-animation">
                    <div class="dot"></div>
                    <div class="dot"></div>
                    <div class="dot"></div>
                </div>
            """, unsafe_allow_html=True)

    if submitted:
        # Prepare conversation history
        conversation_history = "\n".join(
            f"{'User: ' if msg['role'] == 'user' else 'Assistant: '}{msg['content']}"
            for msg in st.session_state.messages[-3:]  # Last 3 messages for context
        )

        # Generate response
        prompt = f"{conversation_history}\nAssistant:"
        try:
            bot_response = generate_response(prompt, model, tokenizer)
        finally:
            # Clear the indicator even if generation fails, so it does not stay
            # visible on later reruns.
            st.session_state.thinking = False

        # Add bot response
        st.session_state.messages.append({"role": "assistant", "content": bot_response})

        # Rerun so the reply is drawn and the typing indicator disappears
        st.rerun()


if __name__ == "__main__":
    main()