import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
# Get the port from Heroku environment, default to 8501 for local development
PORT = int(os.environ.get('PORT', 8501))
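
# Note (assumption, not from the original file): Streamlit cannot change its
# server port from inside a running script, so on Heroku the assigned port is
# normally passed on the command line instead, e.g. via a Procfile entry like:
#
#   web: streamlit run app.py --server.port $PORT --server.address 0.0.0.0
#
# ("app.py" is a placeholder filename.) PORT above is then only used for logging.
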
class LazyLoadModel:
    def __init__(self, model_name='intfloat/multilingual-e5-small'):
        self.model_name = model_name
        self._tokenizer = None
        self._model = None

    @property
    def tokenizer(self):
        # Load the tokenizer on first access and cache it afterwards
        if self._tokenizer is None:
            print("Loading tokenizer...")
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        return self._tokenizer

    @property
    def model(self):
        # Load the model on first access and cache it afterwards
        if self._model is None:
            print("Loading model...")
            # Use float16 to reduce memory and potentially speed up loading.
            # Caveat: half-precision inference on CPU-only hosts depends on the
            # PyTorch version; float32 is a safe fallback if the forward pass fails.
            self._model = AutoModel.from_pretrained(self.model_name, torch_dtype=torch.float16)
        return self._model
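
# Illustrative usage of the lazy loader (a sketch, not part of the app flow):
# from_pretrained only runs on first property access, so construction is cheap
# and repeated accesses reuse the cached objects.
#
#   loader = LazyLoadModel()        # instant: nothing downloaded or loaded yet
#   tok = loader.tokenizer          # first access prints "Loading tokenizer..."
#   assert tok is loader.tokenizer  # later accesses return the cached object
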
class VietnameseChatbot:
    def __init__(self):
        """
        Initialize the Vietnamese chatbot with lazy-loaded model
        """
        self.model_loader = LazyLoadModel()
        # Very minimal conversation data to reduce startup time
        self.conversation_data = [
            {"query": "Xin chào", "response": "Chào bạn!"},            # "Hello" -> "Hi there!"
            {"query": "Bạn là ai?", "response": "Tôi là trợ lý AI."},  # "Who are you?" -> "I am an AI assistant."
        ]

    def embed_text(self, text):
        """
        Generate embeddings for input text
        """
        try:
            # Tokenize and generate embeddings
            inputs = self.model_loader.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
            with torch.no_grad():
                model_output = self.model_loader.model(**inputs)
            # Mean pooling
            embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
            return embeddings.numpy()
        except Exception as e:
            print(f"Embedding error: {e}")
            return None
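
    # Note (an assumption about the model family, not stated in the original):
    # the intfloat/e5 models were trained with "query: " / "passage: " text
    # prefixes, so retrieval quality may improve if callers prepend one, e.g.
    # self.embed_text("query: " + user_query).
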
    def mean_pooling(self, model_output, attention_mask):
        """
        Perform mean pooling on model output
        """
        # Average the token embeddings, using the attention mask to exclude padding
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
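
    # Worked toy example of the pooling math (illustrative values: batch=1,
    # seq_len=3, hidden=2; the third token is padding and is masked out):
    #
    #   emb = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])
    #   mask = torch.tensor([[1, 1, 0]])
    #   pooled = (emb * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
    #   # -> tensor([[2., 3.]]): the mean of the two real tokens only
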
    def get_response(self, user_query):
        """
        Find the most similar response from conversation data
        """
        try:
            # Embed user query
            query_embedding = self.embed_text(user_query)
            if query_embedding is None:
                return "Xin lỗi, đã có lỗi xảy ra."  # "Sorry, an error occurred."
            # Embed conversation data (re-computed on every call)
            conversation_embeddings = np.array([
                self.embed_text(item['query'])[0] for item in self.conversation_data
            ])
            # Calculate cosine similarities
            similarities = cosine_similarity(query_embedding, conversation_embeddings)[0]
            # Find most similar response
            best_match_index = np.argmax(similarities)
            # Return response only if similarity is above the 0.5 threshold
            if similarities[best_match_index] > 0.5:
                return self.conversation_data[best_match_index]['response']
            return "Xin lỗi, tôi không hiểu câu hỏi của bạn."  # "Sorry, I don't understand your question."
        except Exception as e:
            print(f"Response generation error: {e}")
            return "Đã xảy ra lỗi. Xin vui lòng thử lại."  # "An error occurred. Please try again."

def main():
    # Heroku assigns the port via the environment; Streamlit's port must be set
    # at launch (see the Procfile note above), so it is only logged here
    if 'PORT' in os.environ:
        # st.set_option('server.port', PORT)  # not supported at runtime
        print(f"Server starting on port {PORT}")

    st.title("🤖 Trợ Lý AI Tiếng Việt")  # "Vietnamese AI Assistant"

    # Initialize chatbot
    chatbot = VietnameseChatbot()

    # Chat history in session state
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # User input
    if prompt := st.chat_input("Hãy nói gì đó..."):  # "Say something..."
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)
        # Get chatbot response
        response = chatbot.get_response(prompt)
        # Display chatbot response
        with st.chat_message("assistant"):
            st.markdown(response)
        # Add assistant message to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})

    # Logging for Heroku diagnostics
    print("Chatbot application is initializing...")

if __name__ == "__main__":
    main()