File size: 5,052 Bytes
e6868fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
from openai import OpenAI
import time
import html

def predict(message, history, character, api_key, progress=gr.Progress()):
    """Stream a GPT-4 reply to ``message``, yielding the text accumulated so far.

    Args:
        message: The newest user message to answer.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.
        character: Selected character name.
            NOTE(review): not used when building the request; presumably
            meant to select a system prompt -- confirm with the author.
        api_key: OpenAI API key used to construct the client.
        progress: Gradio progress tracker wrapped around the response stream.

    Yields:
        The reply text received so far, once for every chunk that carries
        content.
    """
    client = OpenAI(api_key=api_key)

    messages = []
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        # A turn can have no reply yet (assistant is None), e.g. when an
        # earlier request failed mid-way. Do not forward a null-content
        # assistant message in that case.
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model='gpt-4',
        messages=messages,
        temperature=1.0,
        stream=True
    )

    partial_message = ""
    for chunk in progress.tqdm(response, desc="Generating"):
        # Guard against chunks with an empty ``choices`` list so indexing
        # cannot raise IndexError in the middle of a stream.
        delta = chunk.choices[0].delta.content if chunk.choices else None
        if delta:
            partial_message += delta
            yield partial_message
        # Short pause between chunks, kept from the original pacing.
        time.sleep(0.01)

def format_history(history):
    """Render the chat turns as a single HTML string.

    Each turn is a ``(user_text, ai_text)`` pair. The user text always
    produces a ``user-message`` div; the AI text produces an ``ai-message``
    div only when it is truthy. Text is HTML-escaped and newlines become
    ``<br>``.
    """
    def _as_markup(text):
        # Escape first so the inserted <br> tags are not themselves escaped.
        return html.escape(text).replace('\n', '<br>')

    fragments = []
    for user_text, ai_text in history:
        fragments.append(
            f'<div class="message user-message"><strong>You:</strong> {_as_markup(user_text)}</div>'
        )
        if not ai_text:
            continue
        fragments.append(
            f'<div class="message ai-message"><strong>AI:</strong> {_as_markup(ai_text)}</div>'
        )
    return "".join(fragments)

# Stylesheet handed to gr.Blocks(css=...).
# - #chat-display: fixed-height (600px), vertically scrollable transcript
#   box, plus WebKit scrollbar styling for it.
# - .message / .user-message / .ai-message: the classes emitted by
#   format_history() for each rendered turn; individual messages are capped
#   at 300px and scroll on their own.
css = """
#chat-display {
    height: 600px;
    overflow-y: auto;
    border: 1px solid #ccc;
    padding: 10px;
    margin-bottom: 10px;
}
#chat-display::-webkit-scrollbar {
    width: 10px;
}
#chat-display::-webkit-scrollbar-track {
    background: #f1f1f1;
}
#chat-display::-webkit-scrollbar-thumb {
    background: #888;
}
#chat-display::-webkit-scrollbar-thumb:hover {
    background: #555;
}
.message {
    margin-bottom: 10px;
    max-height: 300px;
    overflow-y: auto;
    word-wrap: break-word;
}
.user-message {
    background-color: #e6f3ff;
    padding: 5px;
    border-radius: 5px;
}
.ai-message {
    background-color: #f0f0f0;
    padding: 5px;
    border-radius: 5px;
}
"""

# JavaScript handed to gr.Blocks(js=...). It does two things:
#   1. Calls maintainScroll('chat-display') once and then runs the returned
#      closure every 100 ms; the closure sets the element's scrollTop either
#      to the saved position or to the bottom.
#   2. On DOMContentLoaded, attaches a keydown handler to
#      "#your_message textarea": Ctrl+Enter clicks "#your_message button";
#      Enter without Shift is cancelled and a newline is inserted at the
#      cursor instead.
# The "\\n" below is a Python escape: the JS source receives "\n".
# NOTE(review): shouldScroll and previousScrollTop are computed once, when
#   maintainScroll() is called, not on every tick -- so each interval tick
#   re-applies the initial decision. Confirm this is intended.
# NOTE(review): the getElementById / querySelector results are used without
#   null checks; this presumably relies on the elements existing when the
#   script runs -- verify against how Gradio injects `js`.
# NOTE(review): check whether the installed Gradio version expects `js` to
#   be a single function expression rather than a list of statements.
js = """
function maintainScroll(element_id) {
    let element = document.getElementById(element_id);
    let shouldScroll = element.scrollTop + element.clientHeight === element.scrollHeight;
    let previousScrollTop = element.scrollTop;
    
    return function() {
        if (!shouldScroll) {
            element.scrollTop = previousScrollTop;
        } else {
            element.scrollTop = element.scrollHeight;
        }
    }
}

let scrollMaintainer = maintainScroll('chat-display');
setInterval(scrollMaintainer, 100);

// Add event listener for Ctrl+Enter and prevent default Enter behavior
document.addEventListener('DOMContentLoaded', (event) => {
    const textbox = document.querySelector('#your_message textarea');
    textbox.addEventListener('keydown', function(e) {
        if (e.ctrlKey && e.key === 'Enter') {
            e.preventDefault();
            document.querySelector('#your_message button').click();
        } else if (e.key === 'Enter' && !e.shiftKey) {
            e.preventDefault();
            const start = this.selectionStart;
            const end = this.selectionEnd;
            this.value = this.value.substring(0, start) + "\\n" + this.value.substring(end);
            this.selectionStart = this.selectionEnd = start + 1;
        }
    });
});
"""

# UI layout and event wiring for the chatbot.
with gr.Blocks(css=css, js=js) as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>My Chatbot</h1>")

    # Conversation turns as [user_text, assistant_text] lists; the HTML
    # component below shows the output of format_history() for this state.
    chat_history = gr.State([])
    chat_display = gr.HTML(elem_id="chat-display")
    msg = gr.Textbox(
        label="Your message",
        lines=2,
        max_lines=10,
        placeholder="Type your message here... (Press Ctrl+Enter to send, Enter for new line)",
        elem_id="your_message",
    )
    clear = gr.Button("Clear")

    dropdown = gr.Dropdown(
        [f"Character {number}" for number in range(1, 14)],
        label="Characters",
        info="Select the character that you'd like to speak to",
        value="Character 1",
    )
    api_key = gr.Textbox(type="password", label="OpenAI API Key")

    def user(user_message, history):
        """Append the submitted message as a new turn with no reply yet.

        Returns an empty string (clears the input box), the updated history,
        and the re-rendered transcript HTML.
        """
        history.append([user_message, None])
        return "", history, format_history(history)

    def bot(history, character, api_key):
        """Stream the assistant reply into the last turn of ``history``.

        Yields the updated history and its rendered HTML after every partial
        reply produced by ``predict``.
        """
        user_message = history[-1][0]
        # The pending turn itself is excluded from the context passed on.
        bot_message_generator = predict(user_message, history[:-1], character, api_key)
        for chunk in bot_message_generator:
            history[-1][1] = chunk
            yield history, format_history(history)

    def reset_chat(_character=None):
        """Return an empty history and an empty transcript.

        The second value feeds the HTML component, so it is an empty string
        rather than a list. ``_character`` is accepted (and ignored) so the
        same callback works for the dropdown change event.
        """
        return [], ""

    msg.submit(user, [msg, chat_history], [msg, chat_history, chat_display]).then(
        bot, [chat_history, dropdown, api_key], [chat_history, chat_display]
    )
    clear.click(reset_chat, None, [chat_history, chat_display], queue=False)
    # Switching character starts a fresh conversation.
    dropdown.change(reset_chat, dropdown, [chat_history, chat_display])

# Enable Gradio's request queue, then start the app.
# max_threads=20 is passed straight to launch(); presumably it caps the
# worker threads Gradio may use -- see the Gradio launch() documentation.
demo.queue()
demo.launch(max_threads=20)