# IOTraining / app.py
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Get the port from the Heroku environment; default to 8501 for local development.
PORT = int(os.environ.get('PORT', 8501))
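
# For reference, on Heroku the port is usually handed to Streamlit on the
# command line rather than set from inside the script; a conventional Procfile
# entry (an assumption about the deployment, not part of this file) would be:
#   web: streamlit run app.py --server.port $PORT --server.address 0.0.0.0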


class LazyLoadModel:
    """Lazily load the tokenizer and model on first access to cut startup time."""

    def __init__(self, model_name='intfloat/multilingual-e5-small'):
        self.model_name = model_name
        self._tokenizer = None
        self._model = None

    @property
    def tokenizer(self):
        if self._tokenizer is None:
            print("Loading tokenizer...")
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self._tokenizer

    @property
    def model(self):
        if self._model is None:
            print("Loading model...")
            # Use float16 to reduce memory and potentially speed up loading.
            self._model = AutoModel.from_pretrained(self.model_name, torch_dtype=torch.float16)
        return self._model
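

# A minimal alternative sketch, not wired into main(): Streamlit >= 1.18 ships
# st.cache_resource, which keeps the loaded model cached across script reruns
# instead of reloading it on every message. The function name is an assumption.
@st.cache_resource
def load_model_cached(model_name='intfloat/multilingual-e5-small'):
    # The returned pair is cached for the lifetime of the server process.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
    return tokenizer, model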


class VietnameseChatbot:
    def __init__(self):
        """
        Initialize the Vietnamese chatbot with a lazy-loaded model.
        """
        self.model_loader = LazyLoadModel()
        # Very minimal conversation data to reduce startup time.
        self.conversation_data = [
            {"query": "Xin chào", "response": "Chào bạn!"},            # "Hello" -> "Hi there!"
            {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI."},  # "Who are you?" -> "I am an AI assistant."
        ]

    def embed_text(self, text):
        """
        Generate an embedding for the input text.
        """
        try:
            # Tokenize, then run the encoder with gradient tracking disabled.
            # Note: the intfloat/multilingual-e5 model card recommends prefixing
            # inputs with "query: " or "passage: "; this app passes raw text.
            inputs = self.model_loader.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            with torch.no_grad():
                model_output = self.model_loader.model(**inputs)
            # Mean pooling over token embeddings.
            embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
            # Cast back to float32 before handing off to NumPy/scikit-learn,
            # since the model runs in float16.
            return embeddings.float().numpy()
        except Exception as e:
            print(f"Embedding error: {e}")
            return None

    def mean_pooling(self, model_output, attention_mask):
        """
        Mean-pool token embeddings, ignoring padding via the attention mask.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
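
    # Shape walkthrough for mean_pooling (illustrative): with batch size B,
    # sequence length T, and hidden size H:
    #   token_embeddings:     (B, T, H)
    #   attention_mask:       (B, T) -> unsqueeze/expand -> (B, T, H)
    #   masked sum over T, divided by token count over T -> pooled (B, H)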

    def get_response(self, user_query):
        """
        Return the canned response whose query is most similar to the user's.
        """
        try:
            # Embed the user query.
            query_embedding = self.embed_text(user_query)
            if query_embedding is None:
                return "Xin lỗi, đã có lỗi xảy ra."  # "Sorry, an error occurred."
            # Embed the conversation data. Note: this re-encodes every canned
            # query on each call; see the precompute sketch below.
            conversation_embeddings = np.array([
                self.embed_text(item['query'])[0] for item in self.conversation_data
            ])
            # Cosine similarity between the query and every canned query.
            similarities = cosine_similarity(query_embedding, conversation_embeddings)[0]
            # Pick the closest match.
            best_match_index = np.argmax(similarities)
            # Only answer when the best match clears the similarity threshold.
            if similarities[best_match_index] > 0.5:
                return self.conversation_data[best_match_index]['response']
            return "Xin lỗi, tôi không hiểu câu hỏi của bạn."  # "Sorry, I don't understand your question."
        except Exception as e:
            print(f"Response generation error: {e}")
            return "Đã xảy ra lỗi. Xin vui lòng thử lại."  # "An error occurred. Please try again."


def main():
    # Heroku assigns the port via the PORT environment variable. Streamlit's
    # port must be set on the command line or in config (server.port), not at
    # runtime from the script, so it is only logged here.
    if 'PORT' in os.environ:
        print(f"Server starting on port {PORT}")

    st.title("🤖 Trợ Lý AI Tiếng Việt")  # "Vietnamese AI Assistant"

    # Initialize the chatbot. Note: Streamlit re-runs the script on every
    # interaction, so a fresh instance (and a model reload inside LazyLoadModel)
    # happens each time; the load_model_cached sketch above avoids this.
    chatbot = VietnameseChatbot()

    # Keep chat history in session state so it survives reruns.
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Replay the chat history.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Handle user input.
    if prompt := st.chat_input("Hãy nói gì đó..."):  # "Say something..."
        # Record and display the user message.
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Get and display the chatbot response.
        response = chatbot.get_response(prompt)
        with st.chat_message("assistant"):
            st.markdown(response)

        # Record the assistant message.
        st.session_state.messages.append({"role": "assistant", "content": response})

    # Logging for Heroku diagnostics (runs on every rerun).
    print("Chatbot application is initializing...")


if __name__ == "__main__":
    main()
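
# Local usage (standard Streamlit invocation):
#   streamlit run app.py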