File size: 3,676 Bytes
a4c3bcc
7b473e5
cb4d8f6
7b473e5
 
faa133e
 
7b473e5
a4c3bcc
 
985ad3e
55e2485
985ad3e
 
 
faa133e
 
a4c3bcc
985ad3e
 
 
a4c3bcc
985ad3e
 
 
 
 
 
 
 
 
76bb75b
 
985ad3e
 
 
 
 
 
 
 
 
76bb75b
985ad3e
 
 
 
76bb75b
 
 
a4c3bcc
 
faa133e
985ad3e
faa133e
985ad3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c3bcc
985ad3e
 
 
 
 
76bb75b
 
985ad3e
 
a4c3bcc
985ad3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c3bcc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import streamlit as st
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

sec_token=os.getenv("HF_TOKEN")
login(token=sec_token)

import torch

class VietnameseChatbot:
    def __init__(self, model_name="meta-llama/Llama-2-13b-hf"):
        """
        Initialize the Vietnamese chatbot with a pre-trained model
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=sec_token)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, token=sec_token)
        
        # Use GPU if available
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def generate_response(self, conversation_history, max_length=200):
        """
        Generate a response based on conversation history
        """
        # Combine conversation history into a single prompt
        prompt = "\n".join(conversation_history)
        
        # Tokenize input
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        
        # Generate response
        outputs = self.model.generate(
            inputs.input_ids, 
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            temperature=0.7,
            top_k=50,
            top_p=0.95
        )
        
        # Decode the generated response
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Extract only the new part of the response
        response = response[len(prompt):].strip()
        
        return response

def main():
    st.set_page_config(page_title="IOGPT", layout="wide")

    st.title("Chat với IOGPT")
    st.markdown("""
    ### Trò chuyện với IOGPT
    """)

    # Initialize chatbot
    if 'chatbot' not in st.session_state:
        try:
            st.session_state.chatbot = VietnameseChatbot()
        except Exception as e:
            st.error(f"Lỗi khởi tạo mô hình: {e}")
            return

    # Initialize conversation history
    if 'conversation' not in st.session_state:
        st.session_state.conversation = []

    # Chat interface
    with st.form(key='chat_form'):
        user_input = st.text_input("Nhập tin nhắn của bạn:", placeholder="Viết tin nhắn...")
        send_button = st.form_submit_button("Gửi")

    # Process user input
    if send_button and user_input:
        # Add user message to conversation
        st.session_state.conversation.append(f"User: {user_input}")
        
        try:
            # Generate AI response
            with st.spinner('Đang suy nghĩ...'):
                ai_response = st.session_state.chatbot.generate_response(
                    st.session_state.conversation
                )
            
            # Add AI response to conversation
            st.session_state.conversation.append(f"AI: {ai_response}")
        
        except Exception as e:
            st.error(f"Lỗi trong quá trình trò chuyện: {e}")

    # Display conversation history
    st.subheader("Lịch sử trò chuyện")
    for msg in st.session_state.conversation:
        if msg.startswith("User:"):
            st.markdown(f"**{msg}**")
        else:
            st.markdown(f"*{msg}*")

    # Model and usage information
    st.sidebar.header("Thông tin mô hình")
    st.sidebar.info("""
    ### Mô hình AI Tiếng Việt
    - Được huấn luyện trên dữ liệu tiếng Việt
    - Hỗ trợ trò chuyện đa dạng
    - Lưu ý: Chất lượng trả lời phụ thuộc vào mô hình
    """)

if __name__ == "__main__":
    main()