import os import streamlit as st import torch from transformers import AutoTokenizer, AutoModel import numpy as np from sklearn.metrics.pairwise import cosine_similarity import json @st.cache_resource def load_model_and_tokenizer(model_name='intfloat/multilingual-e5-small'): """ Cached function to load model and tokenizer This ensures the model is loaded only once and reused """ print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(model_name) print("Loading model...") model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16) return tokenizer, model class VietnameseChatbot: def __init__(self, model_name='intfloat/multilingual-e5-small'): """ Initialize the Vietnamese chatbot with pre-loaded model and conversation data """ # Load pre-trained model and tokenizer using cached function self.tokenizer, self.model = load_model_and_tokenizer(model_name) # Load comprehensive conversation dataset self.conversation_data = self._load_conversation_data() # Pre-compute embeddings for faster response generation print("Pre-computing conversation embeddings...") self.conversation_embeddings = self._compute_embeddings() def _load_conversation_data(self): """ Load a comprehensive conversation dataset """ return [ # Greeting conversations {"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"}, {"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."}, {"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."}, # Identity and purpose {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."}, {"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."}, # Small talk {"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."}, {"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."}, # Weather and time {"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."}, {"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."}, # Assistance offers {"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"}, {"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."}, # Farewell {"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."}, {"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."}, ] @st.cache_data def _compute_embeddings(_self): # Add underscore to self parameter """ Pre-compute embeddings for conversation queries Cached to avoid recomputing on every run """ def embed_single_text(text, tokenizer, model): try: # Tokenize and generate embeddings inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True) with torch.no_grad(): model_output = model(**inputs) # Mean pooling token_embeddings = model_output[0] input_mask_expanded = inputs['attention_mask'].unsqueeze(-1).expand(token_embeddings.size()).float() embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) return embeddings.numpy()[0] except Exception as e: print(f"Embedding error: {e}") return None embeddings = [] for conversation in _self.conversation_data: # Use _self instead of self embedding = embed_single_text(conversation['query'], _self.tokenizer, _self.model) # Use _self instead of self if embedding is not None: embeddings.append(embedding) return np.array(embeddings) def embed_text(self, text): """ Generate embeddings for input text """ try: # Tokenize and generate embeddings inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True) with torch.no_grad(): model_output = self.model(**inputs) # Mean pooling embeddings = self.mean_pooling(model_output, inputs['attention_mask']) return embeddings.numpy() except Exception as e: print(f"Embedding error: {e}") return None def mean_pooling(self, model_output, attention_mask): """ Perform mean pooling on model output """ token_embeddings = model_output[0] input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) def get_response(self, user_query): """ Find the most similar response from conversation data """ try: # Embed user query query_embedding = self.embed_text(user_query) if query_embedding is None: return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn." # Calculate cosine similarities similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0] # Find most similar response best_match_index = np.argmax(similarities) # Return response if similarity is above threshold if similarities[best_match_index] > 0.5: return self.conversation_data[best_match_index]['response'] return "Xin lỗi, tôi chưa hiểu rõ câu hỏi của bạn. Bạn có thể diễn đạt lại được không?" except Exception as e: print(f"Response generation error: {e}") return "Đã xảy ra lỗi. Xin vui lòng thử lại." @st.cache_resource def initialize_chatbot(): """ Cached function to initialize the chatbot This ensures the chatbot is created only once """ return VietnameseChatbot() def main(): st.set_page_config( page_title="Trợ Lý AI Tiếng Việt", page_icon="🤖", ) st.title("🤖 Trợ Lý AI Tiếng Việt") st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ") # Initialize chatbot using cached initialization chatbot = initialize_chatbot() # Chat history in session state if 'messages' not in st.session_state: st.session_state.messages = [] # Sidebar for additional information with st.sidebar: st.header("Về Trợ Lý AI") st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.") st.write("Mô hình sử dụng: intfloat/multilingual-e5-small") # Display chat messages for message in st.session_state.messages: with st.chat_message(message["role"]): st.markdown(message["content"]) # User input if prompt := st.chat_input("Hãy nói gì đó..."): # Add user message to chat history st.session_state.messages.append({"role": "user", "content": prompt}) # Display user message with st.chat_message("user"): st.markdown(prompt) # Get chatbot response response = chatbot.get_response(prompt) # Display chatbot response with st.chat_message("assistant"): st.markdown(response) # Add assistant message to chat history st.session_state.messages.append({"role": "assistant", "content": response}) if __name__ == "__main__": main()