File size: 9,200 Bytes
dc14176
a4c3bcc
 
dc14176
 
 
7f79d8b
985ad3e
80eee0f
 
 
 
 
 
 
 
 
 
 
 
 
 
dc14176
7f79d8b
dc14176
7f79d8b
dc14176
80eee0f
 
dc14176
7f79d8b
 
 
 
 
289913c
7f79d8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289913c
7f79d8b
 
 
 
 
 
 
 
 
 
 
 
 
dc14176
7f79d8b
80eee0f
289913c
7f79d8b
289913c
80eee0f
7f79d8b
289913c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f79d8b
289913c
 
7f79d8b
289913c
7f79d8b
 
dc14176
 
 
 
 
 
7f79d8b
dc14176
 
7f79d8b
dc14176
 
 
 
 
 
 
7f79d8b
dc14176
 
 
 
 
 
 
7f79d8b
dc14176
 
 
 
 
 
 
 
 
7f79d8b
dc14176
 
7f79d8b
dc14176
 
 
 
 
 
 
 
7f79d8b
dc14176
 
 
 
80eee0f
 
 
 
 
 
 
 
dc14176
7f79d8b
 
 
 
dc14176
 
7f79d8b
47ed12b
80eee0f
 
47ed12b
dc14176
 
 
47ed12b
7f79d8b
 
 
 
 
 
dc14176
 
 
 
47ed12b
dc14176
 
 
 
a4c3bcc
dc14176
 
 
47ed12b
dc14176
 
47ed12b
dc14176
 
 
47ed12b
dc14176
 
 
a4c3bcc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json

@st.cache_resource
def load_model_and_tokenizer(model_name='intfloat/multilingual-e5-small'):
    """
    Cached function to load model and tokenizer
    This ensures the model is loaded only once and reused
    """
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    print("Loading model...")
    model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
    
    return tokenizer, model

class VietnameseChatbot:
    def __init__(self, model_name='intfloat/multilingual-e5-small'):
        """
        Initialize the Vietnamese chatbot with pre-loaded model and conversation data
        """
        # Load pre-trained model and tokenizer using cached function
        self.tokenizer, self.model = load_model_and_tokenizer(model_name)
        
        # Load comprehensive conversation dataset
        self.conversation_data = self._load_conversation_data()
        
        # Pre-compute embeddings for faster response generation
        print("Pre-computing conversation embeddings...")
        self.conversation_embeddings = self._compute_embeddings()

    def _load_conversation_data(self):
        """
        Load a comprehensive conversation dataset
        """
        return [
            # Greeting conversations
            {"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"},
            {"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."},
            {"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."},
            
            # Identity and purpose
            {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."},
            {"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."},
            
            # Small talk
            {"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."},
            {"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."},
            
            # Weather and time
            {"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."},
            {"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."},
            
            # Assistance offers
            {"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"},
            {"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."},
            
            # Farewell
            {"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."},
            {"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."},
        ]

    @st.cache_data
    def _compute_embeddings(queries):
        """
        Pre-compute embeddings for conversation queries
        Cached to avoid recomputing on every run
        """
        def embed_single_text(text, tokenizer, model):
            try:
                # Tokenize and generate embeddings
                inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
                
                with torch.no_grad():
                    model_output = model(**inputs)
                
                # Mean pooling
                token_embeddings = model_output[0]
                input_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float()
                embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                
                return embeddings.numpy()[0]
            except Exception as e:
                print(f"Embedding error: {e}")
                return None

        # Import these arguments to make the function self-contained
        from transformers import AutoTokenizer, AutoModel
        tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-small')
        model = AutoModel.from_pretrained('intfloat/multilingual-e5-small', torch_dtype=torch.float16)

        embeddings = []
        for query in queries:
            embedding = embed_single_text(query['query'], tokenizer, model)
            if embedding is not None:
                embeddings.append(embedding)
        return np.array(embeddings)

    def embed_text(self, text):
        """
        Generate embeddings for input text
        """
        try:
            # Tokenize and generate embeddings
            inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            
            with torch.no_grad():
                model_output = self.model(**inputs)
            
            # Mean pooling
            embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
            return embeddings.numpy()
        except Exception as e:
            print(f"Embedding error: {e}")
            return None

    def mean_pooling(self, model_output, attention_mask):
        """
        Perform mean pooling on model output
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def get_response(self, user_query):
        """
        Find the most similar response from conversation data
        """
        try:
            # Embed user query
            query_embedding = self.embed_text(user_query)
            
            if query_embedding is None:
                return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn."
            
            # Calculate cosine similarities
            similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0]
            
            # Find most similar response
            best_match_index = np.argmax(similarities)
            
            # Return response if similarity is above threshold
            if similarities[best_match_index] > 0.5:
                return self.conversation_data[best_match_index]['response']
            
            return "Xin lỗi, tôi chưa hiểu rõ câu hỏi của bạn. Bạn có thể diễn đạt lại được không?"
        except Exception as e:
            print(f"Response generation error: {e}")
            return "Đã xảy ra lỗi. Xin vui lòng thử lại."

@st.cache_resource
def initialize_chatbot():
    """
    Cached function to initialize the chatbot
    This ensures the chatbot is created only once
    """
    return VietnameseChatbot()

def main():
    st.set_page_config(
        page_title="Trợ Lý AI Tiếng Việt",
        page_icon="🤖",
    )

    st.title("🤖 Trợ Lý AI Tiếng Việt")
    st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ")
    
    # Initialize chatbot using cached initialization
    chatbot = initialize_chatbot()
    
    # Chat history in session state
    if 'messages' not in st.session_state:
        st.session_state.messages = []
    
    # Sidebar for additional information
    with st.sidebar:
        st.header("Về Trợ Lý AI")
        st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.")
        st.write("Mô hình sử dụng: intfloat/multilingual-e5-small")
    
    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    
    # User input
    if prompt := st.chat_input("Hãy nói gì đó..."):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        
        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)
        
        # Get chatbot response
        response = chatbot.get_response(prompt)
        
        # Display chatbot response
        with st.chat_message("assistant"):
            st.markdown(response)
        
        # Add assistant message to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})

if __name__ == "__main__":
    main()