liuyizhang commited on
Commit
459631c
β€’
1 Parent(s): 5645dcc

add system prompt

Browse files
Files changed (1) hide show
  1. app.py +122 -20
app.py CHANGED
@@ -5,7 +5,7 @@ import gradio as gr
5
 
6
  from loguru import logger
7
  import paddlehub as hub
8
- import random
9
  from encoder import get_encoder
10
 
11
  openai.api_key = os.getenv("OPENAI_API_KEY")
@@ -96,6 +96,10 @@ start_work = """async() => {
96
  img.style.height = img.offsetWidth + 'px';
97
  }
98
  function load_conversation(chatbot) {
 
 
 
 
99
  var json_str = localStorage.getItem('chatgpt_conversations');
100
  if (json_str) {
101
  var conversations_clear = new Array();
@@ -187,6 +191,7 @@ start_work = """async() => {
187
  window['chat_bot1'].style.height = new_height;
188
  window['chat_bot1'].children[1].style.height = new_height;
189
  window['chat_bot1'].children[0].style.top = (parseInt(window['chat_bot1'].style.height)-window['chat_bot1'].children[0].offsetHeight-2) + 'px';
 
190
  prompt_row.children[0].style.flex = 'auto';
191
  prompt_row.children[0].style.width = '100%';
192
  window['gradioEl'].querySelectorAll('#chat_radio')[0].style.flex = 'auto';
@@ -195,7 +200,8 @@ start_work = """async() => {
195
  window['chat_bot1'].children[1].setAttribute('style', 'border-bottom-right-radius:0;top:unset;bottom:0;padding-left:0.1rem');
196
  window['gradioEl'].querySelectorAll('#btns_row')[0].children[0].setAttribute('style', 'min-width: min(10px, 100%); flex-grow: 1');
197
  window['gradioEl'].querySelectorAll('#btns_row')[0].children[1].setAttribute('style', 'min-width: min(10px, 100%); flex-grow: 1');
198
-
 
199
  load_conversation(window['chat_bot1'].children[1].children[0]);
200
  window['chat_bot1'].children[1].scrollTop = window['chat_bot1'].children[1].scrollHeight;
201
 
@@ -233,6 +239,7 @@ start_work = """async() => {
233
  try {
234
  if (window['chat_radio_0'].checked) {
235
  dot_flashing = window['chat_bot'].children[1].children[0].querySelectorAll('.dot-flashing');
 
236
  if (window['chat_bot'].children[1].children[0].children.length > window['div_count'] && dot_flashing.length == 0) {
237
  new_len = window['chat_bot'].children[1].children[0].children.length - window['div_count'];
238
  for (var i = 0; i < new_len; i++) {
@@ -254,6 +261,7 @@ start_work = """async() => {
254
  img_index = 0;
255
  draw_prompt_en = window['my_prompt_en'].value;
256
  if (window['doCheckPrompt'] == 0 && window['prevPrompt'] != draw_prompt_en) {
 
257
  console.log('_____draw_prompt_en___[' + draw_prompt_en + ']_');
258
  window['doCheckPrompt'] = 1;
259
  window['prevPrompt'] = draw_prompt_en;
@@ -284,6 +292,7 @@ start_work = """async() => {
284
  user_div.dataset.testid = 'user';
285
  user_div.innerHTML = "<p>δ½œη”»: " + window['draw_prompt'] + "</p><img></img>";
286
  window['chat_bot1'].children[1].children[0].appendChild(user_div);
 
287
  var bot_div = document.createElement("div");
288
  bot_div.className = "message bot svelte-134zwfa";
289
  bot_div.style.backgroundColor = "#2563eb";
@@ -318,7 +327,7 @@ start_work = """async() => {
318
  window['chat_bot1'].children[0].textContent = '';
319
  }
320
  }
321
-
322
  } catch(e) {
323
  }
324
  }
@@ -355,16 +364,19 @@ def set_openai_api_key(api_key):
355
  if api_key and api_key.startswith("sk-") and len(api_key) > 50:
356
  openai.api_key = api_key
357
 
358
- def get_response_from_openai(input, chat_history, model_radio):
359
  error_1 = 'You exceeded your current quota, please check your plan and billing details.'
360
- def openai_create(input_list, model_radio):
361
  try:
362
  # print(f'input_list={input_list}')
363
  input_list_len = len(input_list)
364
  out_prompt = ''
365
  messages = []
 
366
  if model_radio == 'GPT-3.0':
367
  out_prompt = 'AI:'
 
 
368
  for i in range(input_list_len):
369
  input = input_list[input_list_len-i-1].replace("<br>", '\n\n')
370
  if input.startswith("Openai said:"):
@@ -425,6 +437,94 @@ def get_response_from_openai(input, chat_history, model_radio):
425
  ret = f"Openai said: {e} Perhaps enter your OpenAI API key."
426
  return ret, {"completion_tokens": -1, "prompt_tokens": -1, "total_tokens": -1}
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  # logger.info(f'chat_history = {chat_history}')
429
  chat_history_list = []
430
  chat_history = chat_history.replace("<p>", "").replace("</p>", "")
@@ -432,11 +532,11 @@ def get_response_from_openai(input, chat_history, model_radio):
432
  chat_history_list = json.loads(chat_history)
433
  chat_history_list.append(f'☟:{input}')
434
 
435
- output, response_usage = openai_create(chat_history_list, model_radio)
436
  logger.info(f'response_usage={response_usage}')
437
  return output
438
 
439
- def chat(input0, input1, chat_radio, model_radio, all_chat_history, chat_history):
440
  all_chat = []
441
  if all_chat_history != '':
442
  all_chat = json.loads(all_chat_history)
@@ -445,7 +545,7 @@ def chat(input0, input1, chat_radio, model_radio, all_chat_history, chat_history
445
  return all_chat, json.dumps(all_chat), input0, input1
446
 
447
  if chat_radio == "Talk to chatGPT":
448
- response = get_response_from_openai(input0, chat_history, model_radio)
449
  all_chat.append((input0, response))
450
  return all_chat, json.dumps(all_chat), '', input1
451
  else:
@@ -454,9 +554,9 @@ def chat(input0, input1, chat_radio, model_radio, all_chat_history, chat_history
454
 
455
  def chat_radio_change(chat_radio):
456
  if chat_radio == "Talk to chatGPT":
457
- return gr.Radio.update(visible=True), gr.Text.update(visible=True)
458
  else:
459
- return gr.Radio.update(visible=False), gr.Text.update(visible=False)
460
 
461
  with gr.Blocks(title='Talk to chatGPT') as demo:
462
  with gr.Row(elem_id="page_0", visible=False) as page_0:
@@ -465,20 +565,21 @@ with gr.Blocks(title='Talk to chatGPT') as demo:
465
  with gr.Box():
466
  with gr.Row():
467
  start_button = gr.Button("Let's talk to chatGPT!", elem_id="start-btn", visible=True)
468
- start_button.click(fn=None, inputs=[], outputs=[], _js=start_work)
469
-
470
  with gr.Row(elem_id="page_2", visible=False) as page_2:
471
  with gr.Row(elem_id="chat_row"):
472
  chatbot = gr.Chatbot(elem_id="chat_bot", visible=False).style(color_map=("green", "blue"))
473
  chatbot1 = gr.Chatbot(elem_id="chat_bot1").style(color_map=("green", "blue"))
 
 
474
  with gr.Row(elem_id="prompt_row"):
475
  prompt_input0 = gr.Textbox(lines=2, label="input", elem_id="my_prompt", show_label=True)
476
  prompt_input1 = gr.Textbox(lines=4, label="prompt", elem_id="my_prompt_en", visible=False)
477
  chat_history = gr.Textbox(lines=4, label="chat_history", elem_id="chat_history", visible=False)
478
- all_chat_history = gr.Textbox(lines=4, label="δΌšθ―δΈŠδΈ‹ζ–‡οΌš", elem_id="all_chat_history", visible=False)
479
 
480
  chat_radio = gr.Radio(["Talk to chatGPT", "Text to Image"], elem_id="chat_radio",value="Talk to chatGPT", show_label=False, visible=True)
481
- model_radio = gr.Radio(["GPT-3.0", "GPT-3.5"], elem_id="model_radio", value="GPT-3.5",
482
  label='GPT model: ', show_label=True,interactive=True, visible=True)
483
  openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter",
484
  show_label=False, lines=1, type='password')
@@ -495,13 +596,14 @@ with gr.Blocks(title='Talk to chatGPT') as demo:
495
  rounded=(True, True, True, True),
496
  width=100
497
  )
498
- submit_btn.click(fn=chat,
499
- inputs=[prompt_input0, prompt_input1, chat_radio, model_radio, all_chat_history, chat_history],
500
- outputs=[chatbot, all_chat_history, prompt_input0, prompt_input1],
501
- )
502
  with gr.Row(elem_id='tab_img', visible=False).style(height=5):
503
  tab_img = gr.TabbedInterface(tab_actions, tab_titles)
504
 
 
 
 
 
505
  openai_api_key_textbox.change(set_openai_api_key,
506
  inputs=[openai_api_key_textbox],
507
  outputs=[])
@@ -510,7 +612,7 @@ with gr.Blocks(title='Talk to chatGPT') as demo:
510
  outputs=[])
511
  chat_radio.change(fn=chat_radio_change,
512
  inputs=[chat_radio],
513
- outputs=[model_radio, openai_api_key_textbox],
514
  )
515
 
516
- demo.launch(debug = True)
 
5
 
6
  from loguru import logger
7
  import paddlehub as hub
8
+ import random, re
9
  from encoder import get_encoder
10
 
11
  openai.api_key = os.getenv("OPENAI_API_KEY")
 
96
  img.style.height = img.offsetWidth + 'px';
97
  }
98
  function load_conversation(chatbot) {
99
+ var my_prompt_system_value = localStorage.getItem('my_prompt_system');
100
+ if (typeof my_prompt_system_value !== 'undefined') {
101
+ setNativeValue(window['my_prompt_system'], my_prompt_system_value);
102
+ }
103
  var json_str = localStorage.getItem('chatgpt_conversations');
104
  if (json_str) {
105
  var conversations_clear = new Array();
 
191
  window['chat_bot1'].style.height = new_height;
192
  window['chat_bot1'].children[1].style.height = new_height;
193
  window['chat_bot1'].children[0].style.top = (parseInt(window['chat_bot1'].style.height)-window['chat_bot1'].children[0].offsetHeight-2) + 'px';
194
+
195
  prompt_row.children[0].style.flex = 'auto';
196
  prompt_row.children[0].style.width = '100%';
197
  window['gradioEl'].querySelectorAll('#chat_radio')[0].style.flex = 'auto';
 
200
  window['chat_bot1'].children[1].setAttribute('style', 'border-bottom-right-radius:0;top:unset;bottom:0;padding-left:0.1rem');
201
  window['gradioEl'].querySelectorAll('#btns_row')[0].children[0].setAttribute('style', 'min-width: min(10px, 100%); flex-grow: 1');
202
  window['gradioEl'].querySelectorAll('#btns_row')[0].children[1].setAttribute('style', 'min-width: min(10px, 100%); flex-grow: 1');
203
+ window['my_prompt_system'] = window['gradioEl'].querySelectorAll('#my_prompt_system')[0].querySelectorAll('textarea')[0];
204
+
205
  load_conversation(window['chat_bot1'].children[1].children[0]);
206
  window['chat_bot1'].children[1].scrollTop = window['chat_bot1'].children[1].scrollHeight;
207
 
 
239
  try {
240
  if (window['chat_radio_0'].checked) {
241
  dot_flashing = window['chat_bot'].children[1].children[0].querySelectorAll('.dot-flashing');
242
+
243
  if (window['chat_bot'].children[1].children[0].children.length > window['div_count'] && dot_flashing.length == 0) {
244
  new_len = window['chat_bot'].children[1].children[0].children.length - window['div_count'];
245
  for (var i = 0; i < new_len; i++) {
 
261
  img_index = 0;
262
  draw_prompt_en = window['my_prompt_en'].value;
263
  if (window['doCheckPrompt'] == 0 && window['prevPrompt'] != draw_prompt_en) {
264
+
265
  console.log('_____draw_prompt_en___[' + draw_prompt_en + ']_');
266
  window['doCheckPrompt'] = 1;
267
  window['prevPrompt'] = draw_prompt_en;
 
292
  user_div.dataset.testid = 'user';
293
  user_div.innerHTML = "<p>δ½œη”»: " + window['draw_prompt'] + "</p><img></img>";
294
  window['chat_bot1'].children[1].children[0].appendChild(user_div);
295
+
296
  var bot_div = document.createElement("div");
297
  bot_div.className = "message bot svelte-134zwfa";
298
  bot_div.style.backgroundColor = "#2563eb";
 
327
  window['chat_bot1'].children[0].textContent = '';
328
  }
329
  }
330
+ localStorage.setItem('my_prompt_system', window['my_prompt_system'].value);
331
  } catch(e) {
332
  }
333
  }
 
364
  if api_key and api_key.startswith("sk-") and len(api_key) > 50:
365
  openai.api_key = api_key
366
 
367
+ def get_response_from_openai(input, prompt_system, chat_history, model_radio, temperature=0.0):
368
  error_1 = 'You exceeded your current quota, please check your plan and billing details.'
369
+ def openai_create_old(input_list, prompt_system, model_radio):
370
  try:
371
  # print(f'input_list={input_list}')
372
  input_list_len = len(input_list)
373
  out_prompt = ''
374
  messages = []
375
+ prompt_system_token = prompt_system
376
  if model_radio == 'GPT-3.0':
377
  out_prompt = 'AI:'
378
+ if prompt_system != '':
379
+ prompt_system_token = f'Human:{prompt_system}'
380
  for i in range(input_list_len):
381
  input = input_list[input_list_len-i-1].replace("<br>", '\n\n')
382
  if input.startswith("Openai said:"):
 
437
  ret = f"Openai said: {e} Perhaps enter your OpenAI API key."
438
  return ret, {"completion_tokens": -1, "prompt_tokens": -1, "total_tokens": -1}
439
 
440
+ def openai_create(input_list, prompt_system, model_radio):
441
+ try:
442
+ input_list_len = len(input_list)
443
+ out_prompt = ''
444
+ messages = []
445
+ prompt_system_token = prompt_system
446
+ if model_radio in ['GPT-3.0']:
447
+ out_prompt += 'AI:'
448
+ if prompt_system != '':
449
+ prompt_system_token = f'Human:{prompt_system}'
450
+ for i in range(input_list_len):
451
+ input = input_list[input_list_len-i-1].replace("<br>", '\n\n')
452
+ if input.startswith("Openai said:"):
453
+ input = "☝:"
454
+ if input.startswith("☝:"):
455
+ if model_radio in ['GPT-3.0']:
456
+ out_prompt = input.replace("☝:", "AI:") + '\n' + out_prompt
457
+ else:
458
+ out_prompt = input.replace("☝:", "") + out_prompt
459
+ messages.insert(0, {"role": "assistant", "content": input.replace("☝:", "")})
460
+ elif input.startswith("☟:"):
461
+ if model_radio in ['GPT-3.0']:
462
+ out_prompt = input.replace("☟:", "Human:") + '\n' + out_prompt
463
+ else:
464
+ out_prompt = input.replace("☟:", "") + out_prompt
465
+ messages.insert(0, {"role": "user", "content": input.replace("☟:", "")})
466
+ tokens = token_encoder.encode(out_prompt + prompt_system_token)
467
+ if model_radio in ['GPT-4.0']:
468
+ if len(tokens) > max_input_tokens + total_tokens:
469
+ break
470
+ else:
471
+ if len(tokens) > max_input_tokens:
472
+ break
473
+
474
+ if prompt_system != '':
475
+ if model_radio in ['GPT-3.0']:
476
+ out_prompt = prompt_system_token + out_prompt
477
+ else:
478
+ out_prompt += prompt_system
479
+ messages.insert(0, {"role": "system", "content": prompt_system})
480
+
481
+ if model_radio in ['GPT-3.0']:
482
+ print(f'response_3.0_out_prompt__:{out_prompt}')
483
+ response = openai.Completion.create(
484
+ model="text-davinci-003",
485
+ prompt=out_prompt,
486
+ temperature=temperature,
487
+ max_tokens=max_output_tokens,
488
+ top_p=1,
489
+ frequency_penalty=0,
490
+ presence_penalty=0,
491
+ stop=[" Human:", " AI:"]
492
+ )
493
+ print(f'response_3.0_response__:{response}')
494
+ ret = response.choices[0].text
495
+ else:
496
+ print(f'response_{model_radio}_messages__:{messages}')
497
+ if model_radio in ['GPT-3.5']:
498
+ model_name = "gpt-3.5-turbo"
499
+ else:
500
+ model_name = "gpt-4"
501
+ response = openai.ChatCompletion.create(
502
+ model=model_name,
503
+ messages=messages,
504
+ temperature=temperature,
505
+ max_tokens=max_output_tokens,
506
+ top_p=1,
507
+ frequency_penalty=0,
508
+ presence_penalty=0,
509
+ stop=[" Human:", " AI:"]
510
+ )
511
+ print(f'response_{model_radio}_response__:{response}')
512
+ ret = response.choices[0].message['content']
513
+
514
+ if ret.startswith("\n\n"):
515
+ ret = ret.replace("\n\n", '')
516
+ ret = ret.replace('\n', '<br>')
517
+ ret = re.sub(r"```(.*?)```", r"<pre style='background-color: #ddd;margin-left:20px;padding-left:10px;'>\1</pre>", ret)
518
+ ret = re.sub(r"<pre style='background-color: #ddd;margin-left:20px;padding-left:10px;'>(.*?)</pre>", lambda m: "<pre style='background-color: #ddd; color:#000; margin-left:20px;padding-left:10px;'>" + m.group(1).replace('<br>', '\n') + '</pre>', ret)
519
+ # logger.info(f'ret_2_={ret}')
520
+ if ret == '':
521
+ ret = f"Openai said: I'm too tired."
522
+ return ret, response.usage
523
+ except Exception as e:
524
+ logger.info(f"openai_create_error__{e}")
525
+ ret = f"Openai said: {e} Perhaps enter your OpenAI API key."
526
+ return ret, {"completion_tokens": -1, "prompt_tokens": -1, "total_tokens": -1}
527
+
528
  # logger.info(f'chat_history = {chat_history}')
529
  chat_history_list = []
530
  chat_history = chat_history.replace("<p>", "").replace("</p>", "")
 
532
  chat_history_list = json.loads(chat_history)
533
  chat_history_list.append(f'☟:{input}')
534
 
535
+ output, response_usage = openai_create(chat_history_list, prompt_system, model_radio)
536
  logger.info(f'response_usage={response_usage}')
537
  return output
538
 
539
+ def chat(input0, input1, prompt_system, chat_radio, model_radio, all_chat_history, chat_history):
540
  all_chat = []
541
  if all_chat_history != '':
542
  all_chat = json.loads(all_chat_history)
 
545
  return all_chat, json.dumps(all_chat), input0, input1
546
 
547
  if chat_radio == "Talk to chatGPT":
548
+ response = get_response_from_openai(input0, prompt_system, chat_history, model_radio)
549
  all_chat.append((input0, response))
550
  return all_chat, json.dumps(all_chat), '', input1
551
  else:
 
554
 
555
  def chat_radio_change(chat_radio):
556
  if chat_radio == "Talk to chatGPT":
557
+ return gr.Radio.update(visible=True), gr.Text.update(visible=True), gr.Textbox.update(visible=True)
558
  else:
559
+ return gr.Radio.update(visible=False), gr.Text.update(visible=False), gr.Textbox.update(visible=False)
560
 
561
  with gr.Blocks(title='Talk to chatGPT') as demo:
562
  with gr.Row(elem_id="page_0", visible=False) as page_0:
 
565
  with gr.Box():
566
  with gr.Row():
567
  start_button = gr.Button("Let's talk to chatGPT!", elem_id="start-btn", visible=True)
568
+ start_button.click(fn=None, inputs=[], outputs=[], _js=start_work)
 
569
  with gr.Row(elem_id="page_2", visible=False) as page_2:
570
  with gr.Row(elem_id="chat_row"):
571
  chatbot = gr.Chatbot(elem_id="chat_bot", visible=False).style(color_map=("green", "blue"))
572
  chatbot1 = gr.Chatbot(elem_id="chat_bot1").style(color_map=("green", "blue"))
573
+ with gr.Row(elem_id="system_prompt_row"):
574
+ prompt_system = gr.Textbox(lines=2, label="system prompt:", elem_id="my_prompt_system", visible=True, show_label=True)
575
  with gr.Row(elem_id="prompt_row"):
576
  prompt_input0 = gr.Textbox(lines=2, label="input", elem_id="my_prompt", show_label=True)
577
  prompt_input1 = gr.Textbox(lines=4, label="prompt", elem_id="my_prompt_en", visible=False)
578
  chat_history = gr.Textbox(lines=4, label="chat_history", elem_id="chat_history", visible=False)
579
+ all_chat_history = gr.Textbox(lines=4, label="contexts:", elem_id="all_chat_history", visible=False)
580
 
581
  chat_radio = gr.Radio(["Talk to chatGPT", "Text to Image"], elem_id="chat_radio",value="Talk to chatGPT", show_label=False, visible=True)
582
+ model_radio = gr.Radio(["GPT-3.0", "GPT-3.5", "GPT-4.0"], elem_id="model_radio", value="GPT-3.5",
583
  label='GPT model: ', show_label=True,interactive=True, visible=True)
584
  openai_api_key_textbox = gr.Textbox(placeholder="Paste your OpenAI API key (sk-...) and hit Enter",
585
  show_label=False, lines=1, type='password')
 
596
  rounded=(True, True, True, True),
597
  width=100
598
  )
599
+
 
 
 
600
  with gr.Row(elem_id='tab_img', visible=False).style(height=5):
601
  tab_img = gr.TabbedInterface(tab_actions, tab_titles)
602
 
603
+ submit_btn.click(fn=chat,
604
+ inputs=[prompt_input0, prompt_input1, prompt_system, chat_radio, model_radio, all_chat_history, chat_history],
605
+ outputs=[chatbot, all_chat_history, prompt_input0, prompt_input1],
606
+ )
607
  openai_api_key_textbox.change(set_openai_api_key,
608
  inputs=[openai_api_key_textbox],
609
  outputs=[])
 
612
  outputs=[])
613
  chat_radio.change(fn=chat_radio_change,
614
  inputs=[chat_radio],
615
+ outputs=[model_radio, openai_api_key_textbox, prompt_system],
616
  )
617
 
618
+ demo.launch(server_name='0.0.0.0', debug = True)