Spaces:
Running
Running
import os | |
import streamlit as st | |
import torch | |
from transformers import AutoTokenizer, AutoModel | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
# Port assigned by the hosting platform (Heroku) through the PORT environment
# variable; 8501 is used when the variable is absent (local development).
PORT = int(os.getenv('PORT', 8501))
class LazyLoadModel:
    """Defer loading of a Hugging Face tokenizer and model until first use.

    Nothing is loaded at construction time. ``tokenizer()`` and ``model()``
    are regular methods: each loads its object on the first call, stores it
    on the instance, and returns the stored object on every later call.
    """

    def __init__(self, model_name='intfloat/multilingual-e5-small'):
        """Remember which checkpoint to load; no loading happens here.

        Args:
            model_name: Identifier passed to ``from_pretrained`` when the
                tokenizer or model is first requested.
        """
        self._model = None
        self._tokenizer = None
        self.model_name = model_name

    def tokenizer(self):
        """Return the tokenizer, loading it on the first call."""
        if self._tokenizer is not None:
            return self._tokenizer
        print("Loading tokenizer...")
        self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self._tokenizer

    def model(self):
        """Return the model, loading it on the first call."""
        if self._model is not None:
            return self._model
        print("Loading model...")
        # float16 weights: lower memory use and potentially faster loading.
        self._model = AutoModel.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
        )
        return self._model
class VietnameseChatbot:
    """Answer queries by nearest-neighbour lookup over canned replies.

    A user query is embedded with the tokenizer/model supplied by
    ``LazyLoadModel`` and compared (cosine similarity) with the embeddings of
    the stored example queries. The response paired with the closest example
    is returned when its similarity exceeds ``SIMILARITY_THRESHOLD``.
    """

    # A stored response is used only when its similarity is strictly greater
    # than this value.
    SIMILARITY_THRESHOLD = 0.5

    def __init__(self):
        """Create the lazy model loader and the built-in conversation pairs."""
        self.model_loader = LazyLoadModel()
        # Very minimal conversation data to reduce startup time.
        self.conversation_data = [
            {"query": "Xin chào", "response": "Chào bạn!"},
            {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI."},
        ]
        # (queries, embedding matrix) of the last embedded conversation data.
        self._embedding_cache = None

    def embed_text(self, text):
        """Return the mean-pooled embedding of ``text`` as a NumPy array.

        Args:
            text: Input passed to the tokenizer.

        Returns:
            The pooled embedding converted with ``Tensor.numpy()``, or
            ``None`` when tokenization or the forward pass raises (the error
            is printed, not propagated).
        """
        try:
            # ``tokenizer`` and ``model`` are zero-argument methods on the
            # loader: call them to obtain the loaded objects, then use those.
            # Passing the text straight to the accessor raises TypeError.
            tokenizer = self.model_loader.tokenizer()
            model = self.model_loader.model()

            inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            with torch.no_grad():
                model_output = model(**inputs)

            embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
            return embeddings.numpy()
        except Exception as e:
            print(f"Embedding error: {e}")
            return None

    def mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings, ignoring masked-out positions.

        Args:
            model_output: Model result whose first element holds the token
                embeddings.
            attention_mask: Mask with 1 for real tokens and 0 for padding.

        Returns:
            Tensor of masked sums along dimension 1 divided by the (clamped)
            number of unmasked tokens.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp so an all-zero mask cannot cause a division by zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def _conversation_embeddings(self):
        """Return the embedding matrix of the stored queries, or ``None``.

        The matrix is cached together with the tuple of queries it was built
        from, so it is recomputed only when ``conversation_data`` changes.
        ``None`` means at least one stored query could not be embedded.
        """
        queries = tuple(item['query'] for item in self.conversation_data)
        cached = getattr(self, '_embedding_cache', None)
        if cached is not None and cached[0] == queries:
            return cached[1]

        vectors = []
        for query in queries:
            embedding = self.embed_text(query)
            if embedding is None:
                return None
            vectors.append(embedding[0])

        matrix = np.array(vectors)
        self._embedding_cache = (queries, matrix)
        return matrix

    def get_response(self, user_query):
        """Return the stored response whose query best matches ``user_query``.

        Args:
            user_query: Text typed by the user.

        Returns:
            The matching response, a "not understood" message when nothing is
            similar enough (or no conversation data exists), or an error
            message when embedding or similarity computation fails.
        """
        try:
            query_embedding = self.embed_text(user_query)
            if query_embedding is None:
                return "Xin lỗi, đã có lỗi xảy ra."

            # Nothing to compare against: no stored query can match.
            if not self.conversation_data:
                return "Xin lỗi, tôi không hiểu câu hỏi của bạn."

            conversation_embeddings = self._conversation_embeddings()
            if conversation_embeddings is None:
                return "Đã xảy ra lỗi. Xin vui lòng thử lại."

            similarities = cosine_similarity(query_embedding, conversation_embeddings)[0]
            best_match_index = int(np.argmax(similarities))

            if similarities[best_match_index] > self.SIMILARITY_THRESHOLD:
                return self.conversation_data[best_match_index]['response']
            return "Xin lỗi, tôi không hiểu câu hỏi của bạn."
        except Exception as e:
            print(f"Response generation error: {e}")
            return "Đã xảy ra lỗi. Xin vui lòng thử lại."
def main():
    """Render the chat UI and answer each submitted prompt with the chatbot.

    The message history is kept in ``st.session_state.messages`` as a list of
    ``{"role": ..., "content": ...}`` dicts and is redrawn on every run.
    """

    @st.cache_resource
    def _load_chatbot():
        # Streamlit re-executes the script on every user interaction. Caching
        # the chatbot as a shared resource keeps its lazily loaded tokenizer
        # and model alive instead of rebuilding them on each rerun.
        return VietnameseChatbot()

    # The platform-assigned port is only logged here; this script does not
    # apply it to the Streamlit server.
    if 'PORT' in os.environ:
        print(f"Server starting on port {PORT}")

    st.title("🤖 Trợ Lý AI Tiếng Việt")

    chatbot = _load_chatbot()

    # Chat history lives in session state so it survives reruns.
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Handle a newly submitted prompt, if any.
    if prompt := st.chat_input("Hãy nói gì đó..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        response = chatbot.get_response(prompt)

        with st.chat_message("assistant"):
            st.markdown(response)
        st.session_state.messages.append({"role": "assistant", "content": response})
# Startup marker printed when the module is loaded, so the platform logs
# (Heroku) show that the application began initializing.
print("Chatbot application is initializing...")

# Run the UI only when the file is executed as the main script.
if __name__ == "__main__":
    main()