File size: 3,616 Bytes
a4c3bcc
cb4d8f6
a4c3bcc
 
985ad3e
 
 
 
 
 
 
a4c3bcc
985ad3e
 
 
a4c3bcc
985ad3e
 
 
 
 
 
 
 
 
76bb75b
 
985ad3e
 
 
 
 
 
 
 
 
76bb75b
985ad3e
 
 
 
76bb75b
 
 
a4c3bcc
 
cb4d8f6
985ad3e
 
cb4d8f6
 
985ad3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c3bcc
985ad3e
 
 
 
 
76bb75b
 
985ad3e
 
a4c3bcc
985ad3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c3bcc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class VietnameseChatbot:
    def __init__(self, model_name="vinai/gpt-neo-1.3B-vietnamese-news"):
        """
        Initialize the Vietnamese chatbot with a pre-trained model
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def generate_response(self, conversation_history, max_length=200):
        """
        Generate a response based on conversation history
        """
        # Combine conversation history into a single prompt
        prompt = "\n".join(conversation_history)
        
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        
        # Generate response
        outputs = self.model.generate(
            inputs.input_ids, 
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            temperature=0.7,
            top_k=50,
            top_p=0.95
        )
        
        # Decode the generated response
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract only the new part of the response
        response = response[len(prompt):].strip()
        
        return response

def main():
    st.set_page_config(
            page_title="IOGPT", 
        page_icon="🇻🇳",
        layout="wide"
    )

    st.title("🇻🇳 Chat với IOGPT")
    st.markdown("""
    ### Trò chuyện với IOGPT
    """)

    # Initialize chatbot
    if 'chatbot' not in st.session_state:
        try:
            st.session_state.chatbot = VietnameseChatbot()
        except Exception as e:
            st.error(f"Lỗi khởi tạo mô hình: {e}")
            return

    # Initialize conversation history
    if 'conversation' not in st.session_state:
        st.session_state.conversation = []

    # Chat interface
    with st.form(key='chat_form'):
        user_input = st.text_input("Nhập tin nhắn của bạn:", placeholder="Viết tin nhắn...")
        send_button = st.form_submit_button("Gửi")

    # Process user input
    if send_button and user_input:
        # Add user message to conversation
        st.session_state.conversation.append(f"User: {user_input}")
        
        try:
            # Generate AI response
            with st.spinner('Đang suy nghĩ...'):
                ai_response = st.session_state.chatbot.generate_response(
                    st.session_state.conversation
                )
            
            # Add AI response to conversation
            st.session_state.conversation.append(f"AI: {ai_response}")
        
        except Exception as e:
            st.error(f"Lỗi trong quá trình trò chuyện: {e}")

    # Display conversation history
    st.subheader("Lịch sử trò chuyện")
    for msg in st.session_state.conversation:
        if msg.startswith("User:"):
            st.markdown(f"**{msg}**")
        else:
            st.markdown(f"*{msg}*")

    # Model and usage information
    st.sidebar.header("Thông tin mô hình")
    st.sidebar.info("""
    ### Mô hình AI Tiếng Việt
    - Được huấn luyện trên dữ liệu tiếng Việt
    - Hỗ trợ trò chuyện đa dạng
    - Lưu ý: Chất lượng trả lời phụ thuộc vào mô hình
    """)

if __name__ == "__main__":
    main()