Spaces:
Running
Running
import os | |
import streamlit as st | |
import torch | |
from transformers import AutoTokenizer, AutoModel | |
import numpy as np | |
from sklearn.metrics.pairwise import cosine_similarity | |
import json | |
class VietnameseChatbot: | |
def __init__(self, model_name='intfloat/multilingual-e5-small'): | |
""" | |
Initialize the Vietnamese chatbot with pre-loaded model and conversation data | |
""" | |
# Load pre-trained model and tokenizer | |
print("Loading tokenizer...") | |
self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
print("Loading model...") | |
self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16) | |
# Load comprehensive conversation dataset | |
self.conversation_data = self._load_conversation_data() | |
# Pre-compute embeddings for faster response generation | |
print("Pre-computing conversation embeddings...") | |
self.conversation_embeddings = self._precompute_embeddings() | |
def _load_conversation_data(self): | |
""" | |
Load a comprehensive conversation dataset | |
""" | |
return [ | |
# Greeting conversations | |
{"query": "Xin chào", "response": "Chào bạn! Tôi có thể giúp gì cho bạn?"}, | |
{"query": "Hi", "response": "Xin chào! Tôi là trợ lý AI tiếng Việt."}, | |
{"query": "Chào buổi sáng", "response": "Chào buổi sáng! Chúc bạn một ngày tốt lành."}, | |
# Identity and purpose | |
{"query": "Bạn là ai?", "response": "Tôi là trợ lý AI được phát triển để hỗ trợ và trò chuyện bằng tiếng Việt."}, | |
{"query": "Bạn từ đâu đến?", "response": "Tôi được phát triển bởi một nhóm kỹ sư AI, và tôn chỉ của tôi là hỗ trợ con người."}, | |
# Small talk | |
{"query": "Bạn thích gì?", "response": "Tôi thích học hỏi và giúp đỡ mọi người. Mỗi cuộc trò chuyện là một cơ hội để tôi phát triển."}, | |
{"query": "Bạn có thể làm gì?", "response": "Tôi có thể trò chuyện, trả lời câu hỏi, và hỗ trợ bạn trong nhiều tình huống khác nhau."}, | |
# Weather and time | |
{"query": "Thời tiết hôm nay thế nào?", "response": "Xin lỗi, tôi không thể cung cấp thông tin thời tiết trực tiếp. Bạn có thể kiểm tra ứng dụng dự báo thời tiết."}, | |
{"query": "Bây giờ là mấy giờ?", "response": "Tôi là trợ lý AI, nên không thể xem đồng hồ. Bạn có thể kiểm tra thiết bị của mình."}, | |
# Assistance offers | |
{"query": "Tôi cần trợ giúp", "response": "Tôi sẵn sàng hỗ trợ bạn. Bạn cần giúp gì?"}, | |
{"query": "Giúp tôi với cái gì đó", "response": "Vâng, tôi có thể hỗ trợ bạn. Hãy cho tôi biết chi tiết hơn."}, | |
# Farewell | |
{"query": "Tạm biệt", "response": "Hẹn gặp lại! Chúc bạn một ngày tốt đẹp."}, | |
{"query": "Bye", "response": "Tạm biệt! Rất vui được trò chuyện với bạn."}, | |
] | |
def _precompute_embeddings(self): | |
""" | |
Pre-compute embeddings for all conversation queries | |
""" | |
embeddings = [] | |
for item in self.conversation_data: | |
embedding = self.embed_text(item['query']) | |
if embedding is not None: | |
embeddings.append(embedding[0]) | |
return np.array(embeddings) | |
def embed_text(self, text): | |
""" | |
Generate embeddings for input text | |
""" | |
try: | |
# Tokenize and generate embeddings | |
inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True) | |
with torch.no_grad(): | |
model_output = self.model(**inputs) | |
# Mean pooling | |
embeddings = self.mean_pooling(model_output, inputs['attention_mask']) | |
return embeddings.numpy() | |
except Exception as e: | |
print(f"Embedding error: {e}") | |
return None | |
def mean_pooling(self, model_output, attention_mask): | |
""" | |
Perform mean pooling on model output | |
""" | |
token_embeddings = model_output[0] | |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | |
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) | |
def get_response(self, user_query): | |
""" | |
Find the most similar response from conversation data | |
""" | |
try: | |
# Embed user query | |
query_embedding = self.embed_text(user_query) | |
if query_embedding is None: | |
return "Xin lỗi, đã có lỗi xảy ra khi phân tích câu hỏi của bạn." | |
# Calculate cosine similarities | |
similarities = cosine_similarity(query_embedding, self.conversation_embeddings)[0] | |
# Find most similar response | |
best_match_index = np.argmax(similarities) | |
# Return response if similarity is above threshold | |
if similarities[best_match_index] > 0.5: | |
return self.conversation_data[best_match_index]['response'] | |
return "Xin lỗi, tôi chưa hiểu rõ câu hỏi của bạn. Bạn có thể diễn đạt lại được không?" | |
except Exception as e: | |
print(f"Response generation error: {e}") | |
return "Đã xảy ra lỗi. Xin vui lòng thử lại." | |
def main(): | |
st.set_page_config( | |
page_title="Trợ Lý AI Tiếng Việt", | |
page_icon="🤖", | |
) | |
st.title("🤖 Trợ Lý AI Tiếng Việt") | |
st.caption("Trò chuyện với trợ lý AI được phát triển bằng mô hình đa ngôn ngữ") | |
# Initialize chatbot (this will pre-load models and embeddings) | |
chatbot = VietnameseChatbot() | |
# Chat history in session state | |
if 'messages' not in st.session_state: | |
st.session_state.messages = [] | |
# Sidebar for additional information | |
with st.sidebar: | |
st.header("Về Trợ Lý AI") | |
st.write("Đây là một trợ lý AI được phát triển để hỗ trợ trò chuyện bằng tiếng Việt.") | |
st.write("Mô hình sử dụng: intfloat/multilingual-e5-small") | |
# Display chat messages | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
# User input | |
if prompt := st.chat_input("Hãy nói gì đó..."): | |
# Add user message to chat history | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
# Display user message | |
with st.chat_message("user"): | |
st.markdown(prompt) | |
# Get chatbot response | |
response = chatbot.get_response(prompt) | |
# Display chatbot response | |
with st.chat_message("assistant"): | |
st.markdown(response) | |
# Add assistant message to chat history | |
st.session_state.messages.append({"role": "assistant", "content": response}) | |
if __name__ == "__main__": | |
main() |