Spaces:
Sleeping
Sleeping
File size: 3,616 Bytes
a4c3bcc cb4d8f6 a4c3bcc 985ad3e a4c3bcc 985ad3e a4c3bcc 985ad3e 76bb75b 985ad3e 76bb75b 985ad3e 76bb75b a4c3bcc cb4d8f6 985ad3e cb4d8f6 985ad3e a4c3bcc 985ad3e 76bb75b 985ad3e a4c3bcc 985ad3e a4c3bcc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class VietnameseChatbot:
def __init__(self, model_name="vinai/gpt-neo-1.3B-vietnamese-news"):
"""
Initialize the Vietnamese chatbot with a pre-trained model
"""
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name)
# Use GPU if available
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(self.device)
def generate_response(self, conversation_history, max_length=200):
"""
Generate a response based on conversation history
"""
# Combine conversation history into a single prompt
prompt = "\n".join(conversation_history)
# Tokenize input
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
# Generate response
outputs = self.model.generate(
inputs.input_ids,
max_length=max_length,
num_return_sequences=1,
no_repeat_ngram_size=2,
temperature=0.7,
top_k=50,
top_p=0.95
)
# Decode the generated response
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract only the new part of the response
response = response[len(prompt):].strip()
return response
def main():
st.set_page_config(
page_title="IOGPT",
page_icon="🇻🇳",
layout="wide"
)
st.title("🇻🇳 Chat với IOGPT")
st.markdown("""
### Trò chuyện với IOGPT
""")
# Initialize chatbot
if 'chatbot' not in st.session_state:
try:
st.session_state.chatbot = VietnameseChatbot()
except Exception as e:
st.error(f"Lỗi khởi tạo mô hình: {e}")
return
# Initialize conversation history
if 'conversation' not in st.session_state:
st.session_state.conversation = []
# Chat interface
with st.form(key='chat_form'):
user_input = st.text_input("Nhập tin nhắn của bạn:", placeholder="Viết tin nhắn...")
send_button = st.form_submit_button("Gửi")
# Process user input
if send_button and user_input:
# Add user message to conversation
st.session_state.conversation.append(f"User: {user_input}")
try:
# Generate AI response
with st.spinner('Đang suy nghĩ...'):
ai_response = st.session_state.chatbot.generate_response(
st.session_state.conversation
)
# Add AI response to conversation
st.session_state.conversation.append(f"AI: {ai_response}")
except Exception as e:
st.error(f"Lỗi trong quá trình trò chuyện: {e}")
# Display conversation history
st.subheader("Lịch sử trò chuyện")
for msg in st.session_state.conversation:
if msg.startswith("User:"):
st.markdown(f"**{msg}**")
else:
st.markdown(f"*{msg}*")
# Model and usage information
st.sidebar.header("Thông tin mô hình")
st.sidebar.info("""
### Mô hình AI Tiếng Việt
- Được huấn luyện trên dữ liệu tiếng Việt
- Hỗ trợ trò chuyện đa dạng
- Lưu ý: Chất lượng trả lời phụ thuộc vào mô hình
""")
if __name__ == "__main__":
main() |