Spaces:
Running
Running
import streamlit as st | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
class VietnameseChatbot: | |
def __init__(self, model_name="vinai/gpt-neo-1.3B-vietnamese-news"): | |
""" | |
Initialize the Vietnamese chatbot with a pre-trained model | |
""" | |
self.tokenizer = AutoTokenizer.from_pretrained(model_name) | |
self.model = AutoModelForCausalLM.from_pretrained(model_name) | |
# Use GPU if available | |
self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
self.model.to(self.device) | |
def generate_response(self, conversation_history, max_length=200): | |
""" | |
Generate a response based on conversation history | |
""" | |
# Combine conversation history into a single prompt | |
prompt = "\n".join(conversation_history) | |
# Tokenize input | |
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) | |
# Generate response | |
outputs = self.model.generate( | |
inputs.input_ids, | |
max_length=max_length, | |
num_return_sequences=1, | |
no_repeat_ngram_size=2, | |
temperature=0.7, | |
top_k=50, | |
top_p=0.95 | |
) | |
# Decode the generated response | |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Extract only the new part of the response | |
response = response[len(prompt):].strip() | |
return response | |
def main(): | |
st.set_page_config( | |
page_title="IOGPT", | |
page_icon="🇻🇳", | |
layout="wide" | |
) | |
st.title("🇻🇳 Chat với IOGPT") | |
st.markdown(""" | |
### Trò chuyện với IOGPT | |
""") | |
# Initialize chatbot | |
if 'chatbot' not in st.session_state: | |
try: | |
st.session_state.chatbot = VietnameseChatbot() | |
except Exception as e: | |
st.error(f"Lỗi khởi tạo mô hình: {e}") | |
return | |
# Initialize conversation history | |
if 'conversation' not in st.session_state: | |
st.session_state.conversation = [] | |
# Chat interface | |
with st.form(key='chat_form'): | |
user_input = st.text_input("Nhập tin nhắn của bạn:", placeholder="Viết tin nhắn...") | |
send_button = st.form_submit_button("Gửi") | |
# Process user input | |
if send_button and user_input: | |
# Add user message to conversation | |
st.session_state.conversation.append(f"User: {user_input}") | |
try: | |
# Generate AI response | |
with st.spinner('Đang suy nghĩ...'): | |
ai_response = st.session_state.chatbot.generate_response( | |
st.session_state.conversation | |
) | |
# Add AI response to conversation | |
st.session_state.conversation.append(f"AI: {ai_response}") | |
except Exception as e: | |
st.error(f"Lỗi trong quá trình trò chuyện: {e}") | |
# Display conversation history | |
st.subheader("Lịch sử trò chuyện") | |
for msg in st.session_state.conversation: | |
if msg.startswith("User:"): | |
st.markdown(f"**{msg}**") | |
else: | |
st.markdown(f"*{msg}*") | |
# Model and usage information | |
st.sidebar.header("Thông tin mô hình") | |
st.sidebar.info(""" | |
### Mô hình AI Tiếng Việt | |
- Được huấn luyện trên dữ liệu tiếng Việt | |
- Hỗ trợ trò chuyện đa dạng | |
- Lưu ý: Chất lượng trả lời phụ thuộc vào mô hình | |
""") | |
if __name__ == "__main__": | |
main() |