IOTraining / app.py
JustKiddo's picture
Update app.py
80eee0f verified
raw
history blame
8.04 kB
import os
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import json
@st.cache_resource
def load_model_and_tokenizer(model_name='intfloat/multilingual-e5-small'):
"""
Cached function to load model and tokenizer
This ensures the model is loaded only once and reused
"""
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Loading model...")
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
return tokenizer, model
class VietnameseChatbot:
def __init__(self, model_name='intfloat/multilingual-e5-small'):
"""
Initialize the Vietnamese chatbot with pre-loaded model and conversation data
"""
# Load pre-trained model and tokenizer using cached function
self.tokenizer, self.model = load_model_and_tokenizer(model_name)
# Load comprehensive conversation dataset
self.conversation_data = self._load_conversation_data()
# Pre-compute embeddings for faster response generation
print("Pre-computing conversation embeddings...")
self.conversation_embeddings = self._precompute_embeddings()
def _load_conversation_data(self):
"""
Load a comprehensive conversation dataset
"""
return [
# Greeting conversations
{"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"},
{"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."},
{"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."},
# Identity and purpose
{"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."},
{"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."},
# Small talk
{"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giú đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."},
{"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."},
# Weather and time
{"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."},
{"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."},
# Assistance offers
{"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"},
{"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."},
# Farewell
{"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."},
{"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."},
]
@st.cache_data
def _precompute_embeddings(self):
"""
Pre-compute embeddings for all conversation queries
Cached to avoid recomputing on every run
"""
embeddings = []
for item in self.conversation_data:
embedding = self.embed_text(item['query'])
if embedding is not None:
embeddings.append(embedding[0])
return np.array(embeddings)
def embed_text(self, text):
"""
Generate embeddings for input text
"""
try:
# Tokenize and generate embeddings
inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
model_output = self.model(**inputs)
# Mean pooling
embeddings = self.mean_pooling(model_output, inputs['attention_mask'])
return embeddings.numpy()
except Exception as e:
print(f"Embedding error: {e}")
return None
def mean_pooling(self, model_output, attention_mask):
"""
Perform mean pooling on model output
"""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def get_response(self, user_query):
"""
Find the most similar response from conversation data
"""
try:
# Embed user query
query_embedding = self.embed_text(user_query)
if query_embedding is None:
return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn."
# Calculate cosine similarities
similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0]
# Find most similar response
best_match_index = np.argmax(similarities)
# Return response if similarity is above threshold
if similarities[best_match_index] > 0.5:
return self.conversation_data[best_match_index]['response']
return "Xin lỗi, tôi chưa hiểu rõ câu hỏi của bạn. Bạn có thể diễn đạt lại được không?"
except Exception as e:
print(f"Response generation error: {e}")
return "Đã xảy ra lỗi. Xin vui lòng thử lại."
@st.cache_resource
def initialize_chatbot():
"""
Cached function to initialize the chatbot
This ensures the chatbot is created only once
"""
return VietnameseChatbot()
def main():
st.set_page_config(
page_title="Trợ Lý AI Tiếng Việt",
page_icon="🤖",
)
st.title("🤖 Trợ Lý AI Tiếng Việt")
st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ")
# Initialize chatbot using cached initialization
chatbot = initialize_chatbot()
# Chat history in session state
if 'messages' not in st.session_state:
st.session_state.messages = []
# Sidebar for additional information
with st.sidebar:
st.header("Về Trợ Lý AI")
st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.")
st.write("Mô hình sử dụng: intfloat/multilingual-e5-small")
# Display chat messages
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# User input
if prompt := st.chat_input("Hãy nói gì đó..."):
# Add user message to chat history
st.session_state.messages.append({"role": "user", "content": prompt})
# Display user message
with st.chat_message("user"):
st.markdown(prompt)
# Get chatbot response
response = chatbot.get_response(prompt)
# Display chatbot response
with st.chat_message("assistant"):
st.markdown(response)
# Add assistant message to chat history
st.session_state.messages.append({"role": "assistant", "content": response})
if __name__ == "__main__":
main()