IOTraining / app.py
JustKiddo's picture
Update app.py
985ad3e verified
raw
history blame
3.62 kB
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class VietnameseChatbot:
def __init__(self, model_name="vinai/gpt-neo-1.3B-vietnamese-news"):
"""
Initialize the Vietnamese chatbot with a pre-trained model
"""
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
# Use GPU if available
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
def generate_response(self, conversation_history, max_length=200):
"""
Generate a response based on conversation history
"""
# Combine conversation history into a single prompt
prompt = "\n".join(conversation_history)
# Tokenize input
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
# Generate response
outputs = self.model.generate(
inputs.input_ids,
max_length=max_length,
num_return_sequences=1,
no_repeat_ngram_size=2,
temperature=0.7,
top_k=50,
top_p=0.95
)
# Decode the generated response
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract only the new part of the response
response = response[len(prompt):].strip()
return response
def main():
st.set_page_config(
page_title="IOGPT",
page_icon="🇻🇳",
layout="wide"
)
st.title("🇻🇳 Chat với IOGPT")
st.markdown("""
### Trò chuyện với IOGPT
""")
# Initialize chatbot
if 'chatbot' not in st.session_state:
try:
st.session_state.chatbot = VietnameseChatbot()
except Exception as e:
st.error(f"Lỗi khởi tạo mô hình: {e}")
return
# Initialize conversation history
if 'conversation' not in st.session_state:
st.session_state.conversation = []
# Chat interface
with st.form(key='chat_form'):
user_input = st.text_input("Nhập tin nhắn của bạn:", placeholder="Viết tin nhắn...")
send_button = st.form_submit_button("Gửi")
# Process user input
if send_button and user_input:
# Add user message to conversation
st.session_state.conversation.append(f"User: {user_input}")
try:
# Generate AI response
with st.spinner('Đang suy nghĩ...'):
ai_response = st.session_state.chatbot.generate_response(
st.session_state.conversation
)
# Add AI response to conversation
st.session_state.conversation.append(f"AI: {ai_response}")
except Exception as e:
st.error(f"Lỗi trong quá trình trò chuyện: {e}")
# Display conversation history
st.subheader("Lịch sử trò chuyện")
for msg in st.session_state.conversation:
if msg.startswith("User:"):
st.markdown(f"**{msg}**")
else:
st.markdown(f"*{msg}*")
# Model and usage information
st.sidebar.header("Thông tin mô hình")
st.sidebar.info("""
### Mô hình AI Tiếng Việt
- Được huấn luyện trên dữ liệu tiếng Việt
- Hỗ trợ trò chuyện đa dạng
- Lưu ý: Chất lượng trả lời phụ thuộc vào mô hình
""")
if __name__ == "__main__":
main()