File size: 3,617 Bytes
a4c3bcc
7b473e5
c99097e
7b473e5
a4c3bcc
 
985ad3e
1934f9e
985ad3e
 
 
2650814
1934f9e
a4c3bcc
985ad3e
 
 
a4c3bcc
985ad3e
 
 
 
 
 
 
 
 
76bb75b
 
985ad3e
 
 
 
 
 
 
 
 
76bb75b
985ad3e
 
 
 
76bb75b
 
 
a4c3bcc
 
faa133e
985ad3e
faa133e
985ad3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c3bcc
985ad3e
 
 
 
 
76bb75b
 
985ad3e
 
a4c3bcc
985ad3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c3bcc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
import os
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from huggingface_hub import login
import torch

class VietnameseChatbot:
    def __init__(self, model_name="tamgrnguyen/Gemma-2-2b-it-Vietnamese-Aesthetic"):
        """
        Initialize the Vietnamese chatbot with a pre-trained model
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def generate_response(self, conversation_history, max_length=200):
        """
        Generate a response based on conversation history
        """
        # Combine conversation history into a single prompt
        prompt = "\n".join(conversation_history)
        
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        
        # Generate response
        outputs = self.model.generate(
            inputs.input_ids, 
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            temperature=0.7,
            top_k=50,
            top_p=0.95
        )
        
        # Decode the generated response
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract only the new part of the response
        response = response[len(prompt):].strip()
        
        return response

def main():
    st.set_page_config(page_title="IOGPT", layout="wide")

    st.title("Chat với IOGPT")
    st.markdown("""
    ### Trò chuyện với IOGPT
    """)

    # Initialize chatbot
    if 'chatbot' not in st.session_state:
        try:
            st.session_state.chatbot = VietnameseChatbot()
        except Exception as e:
            st.error(f"Lỗi khởi tạo mô hình: {e}")
            return

    # Initialize conversation history
    if 'conversation' not in st.session_state:
        st.session_state.conversation = []

    # Chat interface
    with st.form(key='chat_form'):
        user_input = st.text_input("Nhập tin nhắn của bạn:", placeholder="Viết tin nhắn...")
        send_button = st.form_submit_button("Gửi")

    # Process user input
    if send_button and user_input:
        # Add user message to conversation
        st.session_state.conversation.append(f"User: {user_input}")
        
        try:
            # Generate AI response
            with st.spinner('Đang suy nghĩ...'):
                ai_response = st.session_state.chatbot.generate_response(
                    st.session_state.conversation
                )
            
            # Add AI response to conversation
            st.session_state.conversation.append(f"AI: {ai_response}")
        
        except Exception as e:
            st.error(f"Lỗi trong quá trình trò chuyện: {e}")

    # Display conversation history
    st.subheader("Lịch sử trò chuyện")
    for msg in st.session_state.conversation:
        if msg.startswith("User:"):
            st.markdown(f"**{msg}**")
        else:
            st.markdown(f"*{msg}*")

    # Model and usage information
    st.sidebar.header("Thông tin mô hình")
    st.sidebar.info("""
    ### Mô hình AI Tiếng Việt
    - Được huấn luyện trên dữ liệu tiếng Việt
    - Hỗ trợ trò chuyện đa dạng
    - Lưu ý: Chất lượng trả lời phụ thuộc vào mô hình
    """)

if __name__ == "__main__":
    main()