Omnibus committed
Commit c371eda
1 Parent(s): 7fea8b8

Update app.py

Files changed (1)
  1. app.py +22 -38
app.py CHANGED
@@ -4,7 +4,7 @@ from huggingface_hub import InferenceClient
 import random
 ss_client = Client("https://omnibus-html-image-current-tab.hf.space/")
 
-'''models=[
+models=[
     "google/gemma-7b",
     "google/gemma-7b-it",
     "google/gemma-2b",
@@ -15,52 +15,36 @@ InferenceClient(models[0]),
 InferenceClient(models[1]),
 InferenceClient(models[2]),
 InferenceClient(models[3]),
-]'''
-
-
-models=[
-    "google/gemma-7b",
-    "google/gemma-7b-it",
-    "google/gemma-2b",
-    "google/gemma-2b-it",
 ]
-client_z=[]
 
+VERBOSE=False
 
 def load_models(inp):
-
-    print(type(inp))
-    print(inp)
-    print(models[inp])
-    client_z.clear()
-    client_z.append(InferenceClient(models[inp]))
+    if VERBOSE==True:
+        print(type(inp))
+        print(inp)
+        print(models[inp])
+    #client_z.clear()
+    #client_z.append(InferenceClient(models[inp]))
     return gr.update(label=models[inp])
 
-VERBOSE=False
-
 def format_prompt(message, history, cust_p):
-    prompt = ""
+    prompt = "<bos>"
     if history:
-        #<start_of_turn>userHow does the brain work?<end_of_turn><start_of_turn>model
         for user_prompt, bot_response in history:
             prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
-            #print(prompt)
-            prompt += f"<start_of_turn>model\n{bot_response}<end_of_turn>"
-            #print(prompt)
+            prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
+    if VERBOSE==True:
+        print(prompt)
     #prompt += f"<start_of_turn>user\n{message}<end_of_turn>\n<start_of_turn>model\n"
     prompt+=cust_p.replace("USER_INPUT",message)
    return prompt
 
-def custom_prompt(prompt):
-    return prompt
-
-
 def chat_inf(system_prompt,prompt,history,memory,client_choice,seed,temp,tokens,top_p,rep_p,chat_mem,cust_p):
     #token max=8192
+    print(client_choice)
     hist_len=0
-    #client=clients[int(client_choice)-1]
-    client=client_z[0]
-
+    client=clients[int(client_choice)-1]
     if not history:
         history = []
         hist_len=0
@@ -79,7 +63,7 @@ def chat_inf(system_prompt,prompt,history,memory,client_choice,seed,temp,tokens,
     generate_kwargs = dict(
         temperature=temp,
         max_new_tokens=tokens,
-        #top_p=top_p,
+        top_p=top_p,
         repetition_penalty=rep_p,
         do_sample=True,
         seed=seed,
@@ -88,7 +72,7 @@ def chat_inf(system_prompt,prompt,history,memory,client_choice,seed,temp,tokens,
         formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", memory[0-chat_mem:],cust_p)
     else:
         formatted_prompt = format_prompt(prompt, memory[0-chat_mem:],cust_p)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=True)
     output = ""
     for response in stream:
         output += response.token.text
@@ -96,10 +80,10 @@ def chat_inf(system_prompt,prompt,history,memory,client_choice,seed,temp,tokens,
     history.append((prompt,output))
     memory.append((prompt,output))
     yield history,memory
+
     if VERBOSE==True:
         print("\n######### HIST "+str(in_len))
         print("\n######### TOKENS "+str(tokens))
-        #print("\n######### PROMPT "+str(len(formatted_prompt)))
 
 def get_screenshot(chat: list,height=5000,width=600,chatblock=[],theme="light",wait=3000,header=True):
     print(chatblock)
@@ -130,6 +114,8 @@ with gr.Blocks() as app:
         with gr.Column(scale=3):
             inp = gr.Textbox(label="Prompt")
             sys_inp = gr.Textbox(label="System Prompt (optional)")
+            with gr.Accordion("Prompt Format",open=False):
+                custom_prompt=gr.Textbox(label="Modify Prompt Format", info="For testing purposes. 'USER_INPUT' is where 'SYSTEM_PROMPT, PROMPT' will be placed", lines=3,value="<bos><start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model")
             with gr.Row():
                 with gr.Column(scale=2):
                     btn = gr.Button("Chat")
@@ -138,16 +124,14 @@ with gr.Blocks() as app:
                     stop_btn=gr.Button("Stop")
                     clear_btn=gr.Button("Clear")
             client_choice=gr.Dropdown(label="Models",type='index',choices=[c for c in models],value=models[0],interactive=True)
-            with gr.Accordion("Prompt Format",open=False):
-                custom_prompt=gr.Textbox(label="Prompt Format", info="For testing purposes. 'USER_INPUT' is where 'SYSTEM_PROMPT, PROMPT' will be placed", lines=5,value="<start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model")
         with gr.Column(scale=1):
             with gr.Group():
                 rand = gr.Checkbox(label="Random Seed", value=True)
                 seed=gr.Slider(label="Seed", minimum=1, maximum=1111111111111111,step=1, value=rand_val)
                 tokens = gr.Slider(label="Max new tokens",value=1600,minimum=0,maximum=8000,step=64,interactive=True, visible=True,info="The maximum number of tokens")
-                temp=gr.Slider(label="Temperature",step=0.01, minimum=0.01, maximum=1.0, value=0.9)
-                top_p=gr.Slider(label="Top-P",step=0.01, minimum=0.01, maximum=1.0, value=0.9)
-                rep_p=gr.Slider(label="Repetition Penalty",step=0.1, minimum=0.1, maximum=2.0, value=1.0)
+                temp=gr.Slider(label="Temperature",step=0.01, minimum=0.01, maximum=1.0, value=0.49)
+                top_p=gr.Slider(label="Top-P",step=0.01, minimum=0.01, maximum=1.0, value=0.49)
+                rep_p=gr.Slider(label="Repetition Penalty",step=0.01, minimum=0.1, maximum=2.0, value=0.99)
                 chat_mem=gr.Number(label="Chat Memory", info="Number of previous chats to retain",value=4)
         with gr.Accordion(label="Screenshot",open=False):
             with gr.Row():
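
The main functional change in this commit is the prompt format: format_prompt now opens the sequence with <bos> and closes each model turn with <end_of_turn>. As a rough illustration, here is the kind of string that now reaches the model for a one-turn history; the conversation values are made up, and the format string below is a simplified version of the UI default (without its own leading <bos>):

# Sketch of the prompt the updated format_prompt builds; the example
# conversation is hypothetical, the turn markers come from this commit.
history = [("Hi", "Hello!")]                      # prior (user, bot) turns
cust_p = "<start_of_turn>userUSER_INPUT<end_of_turn><start_of_turn>model"

prompt = "<bos>"                                  # sequence starts with <bos>
for user_prompt, bot_response in history:
    prompt += f"<start_of_turn>user{user_prompt}<end_of_turn>"
    prompt += f"<start_of_turn>model{bot_response}<end_of_turn>"
prompt += cust_p.replace("USER_INPUT", "What is 2+2?")

print(prompt)
# <bos><start_of_turn>userHi<end_of_turn><start_of_turn>modelHello!<end_of_turn><start_of_turn>userWhat is 2+2?<end_of_turn><start_of_turn>model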
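
For completeness, the streaming call that chat_inf makes can be reproduced standalone. A minimal sketch, assuming a valid Hugging Face token is available to huggingface_hub; the parameter values echo the new UI defaults, and the seed is a hypothetical stand-in for the app's Seed slider:

from huggingface_hub import InferenceClient

client = InferenceClient("google/gemma-7b-it")  # one of the models in the dropdown
stream = client.text_generation(
    "<bos><start_of_turn>userWhat is 2+2?<end_of_turn><start_of_turn>model",
    temperature=0.49,          # new UI default
    max_new_tokens=1600,       # new UI default
    top_p=0.49,                # re-enabled by this commit
    repetition_penalty=0.99,   # new UI default
    do_sample=True,
    seed=42,                   # hypothetical; the app draws this from the Seed slider
    stream=True,               # yield tokens as they are generated
    details=True,              # each streamed item exposes .token metadata
    return_full_text=True,
)
output = ""
for response in stream:
    output += response.token.text  # accumulate streamed tokens
print(output)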