File size: 5,472 Bytes
e6868fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
from openai import OpenAI
import time
import html

def predict(message, history, character, api_key, progress=gr.Progress(), *,
            model='gpt-4', temperature=1.0):
    """Stream a chat-completion reply to *message*, yielding the text so far.

    Args:
        message: The new user message to answer.
        history: Prior turns as ``[user_text, assistant_text]`` pairs. Turns
            whose ``assistant_text`` is ``None`` or empty (a reply that never
            arrived) contribute only their user message, because the API
            rejects an assistant message with null content.
        character: Selected character name. NOTE(review): currently unused —
            no system prompt is built from it.
        api_key: OpenAI API key used to construct the client.
        progress: Gradio progress tracker wrapped around the stream.
        model: Chat model name (keyword-only; defaults to ``'gpt-4'``).
        temperature: Sampling temperature (keyword-only; defaults to 1.0).

    Yields:
        The accumulated reply text after each non-empty content delta.
    """
    client = OpenAI(api_key=api_key)

    messages = []
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        # An unanswered turn (e.g. a previous request raised mid-stream) would
        # otherwise be sent as null assistant content and fail every request.
        if assistant:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        stream=True,
    )

    partial_message = ""
    for chunk in progress.tqdm(response, desc="Generating"):
        # Some stream chunks may carry an empty ``choices`` list; indexing
        # [0] on those would raise IndexError.
        content = chunk.choices[0].delta.content if chunk.choices else None
        if content:
            partial_message += content
            yield partial_message
        # Short pause between chunks, presumably to pace UI updates — confirm
        # it is still wanted.
        time.sleep(0.01)

def format_history(history):
    """Render chat history as HTML message ``<div>``s.

    Args:
        history: Sequence of ``[user_text, assistant_text]`` pairs. A falsy
            ``assistant_text`` (``None`` or ``""``) renders no AI message.

    Returns:
        Concatenated HTML: one ``user-message`` div per turn, followed by an
        ``ai-message`` div when that turn has a reply. All text is
        HTML-escaped and newlines become ``<br>``.
    """
    def _to_html(text):
        # Escape user/model text before embedding it in markup, then keep
        # line breaks visible.
        return html.escape(text).replace('\n', '<br>')

    # Collect parts and join once; this runs on every streamed chunk, so
    # repeated string += over the whole transcript would be quadratic.
    parts = []
    for human, ai in history:
        parts.append(
            f'<div class="message user-message"><strong>You:</strong> {_to_html(human)}</div>'
        )
        if ai:
            parts.append(
                f'<div class="message ai-message"><strong>AI:</strong> {_to_html(ai)}</div>'
            )
    return "".join(parts)

# Stylesheet passed to gr.Blocks(css=...).
# - #chat-display targets the gr.HTML transcript (elem_id="chat-display"):
#   a fixed 600px-tall box with its own vertical scrollbar.
# - .message / .user-message / .ai-message are the classes emitted by
#   format_history; each message wraps long words and is capped at 300px
#   tall with its own scrollbar.
css = """
#chat-display {
    height: 600px;
    overflow-y: auto;
    border: 1px solid #ccc;
    padding: 10px;
    margin-bottom: 10px;
}
.message {
    margin-bottom: 10px;
    word-wrap: break-word;
    overflow-wrap: break-word;
}
.user-message, .ai-message {
    padding: 5px;
    border-radius: 5px;
    max-height: 300px;
    overflow-y: auto;
}
.user-message {
    background-color: #e6f3ff;
}
.ai-message {
    background-color: #f0f0f0;
}
"""

# Custom JavaScript passed to gr.Blocks(js=...). Gradio evaluates this string
# as a single function expression and calls it once after the page loads, so
# it must be exactly one function (top-level `let`/`const` statements are a
# syntax error there). It also cannot rely on DOMContentLoaded, which has
# already fired by then. The function keeps the #chat-display transcript
# pinned to the bottom while a reply streams, unless the user has scrolled up
# to read, in which case their position is held.
js = """
() => {
    // Within this many pixels of the bottom counts as "at the bottom".
    const BOTTOM_THRESHOLD = 50;
    let lastScrollTop = 0;
    let isNearBottom = true;

    const nearBottom = (el) =>
        el.scrollTop + el.clientHeight >= el.scrollHeight - BOTTOM_THRESHOLD;

    function attach(chatDisplay) {
        // Each update replaces the transcript HTML: follow the tail if the
        // user was at the bottom, otherwise restore where they were.
        const observer = new MutationObserver(() => {
            chatDisplay.scrollTop = isNearBottom ? chatDisplay.scrollHeight : lastScrollTop;
            lastScrollTop = chatDisplay.scrollTop;
        });
        observer.observe(chatDisplay, { childList: true, subtree: true, characterData: true });

        // Track manual scrolling so we know whether to follow the tail.
        chatDisplay.addEventListener('scroll', () => {
            lastScrollTop = chatDisplay.scrollTop;
            isNearBottom = nearBottom(chatDisplay);
        });
    }

    const existing = document.getElementById('chat-display');
    if (existing) {
        attach(existing);
        return;
    }
    // Components may still be mounting when this runs; wait for the element.
    const waiter = new MutationObserver(() => {
        const el = document.getElementById('chat-display');
        if (el) {
            waiter.disconnect();
            attach(el);
        }
    });
    waiter.observe(document.body, { childList: true, subtree: true });
}
"""

def user(user_message, history, character, api_key):
    """Record a new user turn and clear the input box.

    Whitespace-only messages are ignored. Otherwise ``[user_message, None]``
    is appended to *history* in place; the ``None`` marks the reply as
    pending for ``bot`` to fill in.

    ``character`` and ``api_key`` are accepted to match the caller's
    signature but are not used here.

    Returns:
        ``("", history, html)``: an empty string for the textbox, the
        (possibly extended) history, and its rendered HTML.
    """
    if user_message.strip():
        history.append([user_message, None])
    return "", history, format_history(history)

def bot(history, character, api_key):
    """Stream the assistant's reply into the last turn of *history*.

    Args:
        history: Chat history as ``[user_text, assistant_text]`` pairs; the
            last pair is the turn to answer and is updated in place.
        character: Selected character, forwarded to ``predict``.
        api_key: OpenAI API key, forwarded to ``predict``.

    Yields:
        ``(history, html)`` after each streamed chunk, where ``html`` is the
        transcript rendered by ``format_history``.
    """
    # Nothing to answer: no turns yet, or the last turn already has a reply.
    # The latter happens when ``user`` ignored a blank message but this
    # chained step still ran; answering again would overwrite that reply.
    # Yield instead of ``return value``: a generator's return value is
    # discarded, so yielding is what actually refreshes the outputs.
    if not history or history[-1][1] is not None:
        yield history, format_history(history)
        return

    user_message = history[-1][0]
    for partial in predict(user_message, history[:-1], character, api_key):
        history[-1][1] = partial
        yield history, format_history(history)

with gr.Blocks(css=css, js=js) as demo:
    # Layout: title, chat transcript, input row, clear button, then the
    # per-session settings (character and API key).
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>My Chatbot</h1>")

    # Per-session list of [user_text, assistant_text] pairs.
    chat_history = gr.State([])
    # Rendered transcript; sized/scrolled by the "#chat-display" css and js.
    chat_display = gr.HTML(elem_id="chat-display")
    with gr.Row():
        msg = gr.Textbox(
            label="Your message",
            lines=1,
            placeholder="Type your message here... (Press Enter to send)",
        )
        send_btn = gr.Button("Send")
    clear = gr.Button("Clear")

    dropdown = gr.Dropdown(
        [f"Character {i}" for i in range(1, 14)],
        label="Characters",
        info="Select the character that you'd like to speak to",
        value="Character 1",
    )
    api_key = gr.Textbox(type="password", label="OpenAI API Key")

    def send_message(user_message, history, character, api_key):
        """Handler for Send/Enter: record the user turn via ``user``.

        Kept as a named function so the handler name Gradio sees is stable.
        """
        return user(user_message, history, character, api_key)

    # Clicking Send and pressing Enter run the same two-step chain: record
    # the user turn, then stream the bot reply. Registered click-first so any
    # auto-derived endpoint names stay stable.
    for trigger in (send_btn.click, msg.submit):
        trigger(
            send_message,
            [msg, chat_history, dropdown, api_key],
            [msg, chat_history, chat_display],
        ).then(bot, [chat_history, dropdown, api_key], [chat_history, chat_display])

    # gr.HTML takes a string, so reset the transcript to "" rather than a list.
    clear.click(lambda: ([], ""), None, [chat_history, chat_display], queue=False)
    # Switching character starts a fresh conversation.
    dropdown.change(lambda _character: ([], ""), dropdown, [chat_history, chat_display])

# Enable the request queue (needed for streaming generator handlers).
demo.queue()

# Start the server only when run as a script, so importing this module does
# not block on launch().
if __name__ == "__main__":
    demo.launch(max_threads=20)