import os

import numpy as np
import streamlit as st
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer

# Get the port from Heroku environment, default to 8501 for local development
PORT = int(os.environ.get('PORT', 8501))


class LazyLoadModel:
    """Load a Hugging Face tokenizer and model only when first accessed.

    Args:
        model_name: Hub identifier passed to ``from_pretrained``.
        torch_dtype: dtype the model weights are loaded in (default
            ``torch.float16`` to reduce memory).
    """

    def __init__(self, model_name='intfloat/multilingual-e5-small', torch_dtype=torch.float16):
        self.model_name = model_name
        self.torch_dtype = torch_dtype
        self._tokenizer = None
        self._model = None

    @property
    def tokenizer(self):
        """Tokenizer for ``model_name``; loaded on first access, then reused."""
        if self._tokenizer is None:
            print("Loading tokenizer...")
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self._tokenizer

    @property
    def model(self):
        """Model for ``model_name``; loaded on first access, then reused."""
        if self._model is None:
            print("Loading model...")
            # float16 halves weight memory, but the model is never moved off the
            # CPU in this file. NOTE(review): confirm the deployed torch version
            # has half-precision CPU kernels, otherwise pass torch.float32.
            self._model = AutoModel.from_pretrained(self.model_name, torch_dtype=self.torch_dtype)
        return self._model


class VietnameseChatbot:
    """Retrieval chatbot that answers with the response of the most similar stored query."""

    # e5 models expect each input to begin with "query: " (or "passage: ").
    # The same prefix is applied to both the user text and the stored queries
    # so the comparison stays symmetric.
    QUERY_PREFIX = "query: "

    def __init__(self, similarity_threshold=0.5):
        """
        Initialize the Vietnamese chatbot with lazy-loaded model.

        Args:
            similarity_threshold: a stored response is returned only when its
                query's cosine similarity to the user query exceeds this value.
                NOTE(review): e5 models reportedly give high cosine scores even
                for unrelated text, so 0.5 may accept almost anything. Tune this
                against real queries.
        """
        self.model_loader = LazyLoadModel()
        self.similarity_threshold = similarity_threshold
        # Very minimal conversation data to reduce startup time
        self.conversation_data = [
            {"query": "Xin chào", "response": "Chào bạn!"},
            {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI."},
        ]
        # (prefixed queries, their embeddings). Stored as one tuple so readers
        # never see the queries of one state paired with embeddings of another.
        self._embedding_cache = None

    def embed_text(self, text):
        """
        Generate mean-pooled embeddings for input text.

        Args:
            text: a single string or a list of strings.

        Returns:
            A 2-D numpy array with one row per input text, or None if
            tokenization or the forward pass failed (the error is printed).
        """
        try:
            # Tokenize (padding aligns a batch of several strings) and run the model.
            inputs = self.model_loader.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            with torch.no_grad():
                model_output = self.model_loader.model(**inputs)
            # Mean pooling that ignores padding tokens
            embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
            return embeddings.numpy()
        except Exception as e:
            print(f"Embedding error: {e}")
            return None

    def mean_pooling(self, model_output, attention_mask):
        """
        Average token embeddings over the sequence, weighting padding tokens by zero.

        Args:
            model_output: model output whose first element holds the token embeddings.
            attention_mask: tokenizer mask (1 for real tokens, 0 for padding).

        Returns:
            Tensor with one pooled vector per input sequence.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # The clamp avoids division by zero for an all-padding row.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def _get_conversation_embeddings(self):
        """Return embeddings of the stored queries, recomputing only if the queries changed.

        Returns None if embedding fails; a failure is not cached, so the next
        call retries.
        """
        queries = tuple(self.QUERY_PREFIX + item['query'] for item in self.conversation_data)
        cache = self._embedding_cache
        if cache is not None and cache[0] == queries:
            return cache[1]
        # One batched forward pass instead of one pass per stored query.
        embeddings = self.embed_text(list(queries))
        if embeddings is None:
            return None
        self._embedding_cache = (queries, embeddings)
        return embeddings

    def get_response(self, user_query):
        """
        Find the most similar response from conversation data.

        Args:
            user_query: raw text typed by the user.

        Returns:
            The stored response whose query is most similar to ``user_query``,
            if that similarity exceeds ``similarity_threshold``. Otherwise one of
            the Vietnamese fallback messages (not understood / embedding error /
            unexpected error).
        """
        try:
            if not self.conversation_data:
                # Nothing to match against.
                return "Xin lỗi, tôi không hiểu câu hỏi của bạn."

            query_embedding = self.embed_text(self.QUERY_PREFIX + user_query)
            conversation_embeddings = self._get_conversation_embeddings()
            if query_embedding is None or conversation_embeddings is None:
                return "Xin lỗi, đã có lỗi xảy ra."

            # Row 0: similarities of the single user query to every stored query.
            similarities = cosine_similarity(query_embedding, conversation_embeddings)[0]
            best_match_index = int(np.argmax(similarities))

            # Return response if similarity is above threshold
            if similarities[best_match_index] > self.similarity_threshold:
                return self.conversation_data[best_match_index]['response']
            return "Xin lỗi, tôi không hiểu câu hỏi của bạn."
        except Exception as e:
            print(f"Response generation error: {e}")
            return "Đã xảy ra lỗi. Xin vui lòng thử lại."


@st.cache_resource
def _get_chatbot():
    """Return one VietnameseChatbot shared across Streamlit reruns and sessions.

    Streamlit re-executes the script on every interaction. Without this cache,
    each message would create a new chatbot and reload the model.
    """
    return VietnameseChatbot()


def main():
    """Render the chat UI, replay the session's history, and answer a new message."""
    # This only logs the port. Streamlit's server.port cannot be changed from
    # inside the script with st.set_option; it must be given at launch,
    # e.g. `--server.port $PORT`.
    if 'PORT' in os.environ:
        print(f"Server starting on port {PORT}")

    st.title("🤖 Trợ Lý AI Tiếng Việt")

    # Cached chatbot: the model loads at most once per process.
    chatbot = _get_chatbot()

    # Chat history in session state
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # User input
    if prompt := st.chat_input("Hãy nói gì đó..."):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get chatbot response
        response = chatbot.get_response(prompt)

        # Display chatbot response
        with st.chat_message("assistant"):
            st.markdown(response)

        # Add assistant message to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})


# Logging for Heroku diagnostics
print("Chatbot application is initializing...")

if __name__ == "__main__":
    main()