zetavg committed
Commit • 517781a
1 Parent(s): fb9b56d
auto select prompt template for model
Browse files:
- llama_lora/ui/inference_ui.py +57 -13
- llama_lora/ui/main_page.py +41 -2
- llama_lora/utils/data.py +16 -0
llama_lora/ui/inference_ui.py
CHANGED
@@ -11,7 +11,8 @@ from ..models import get_base_model, get_model_with_lora, get_tokenizer, get_dev
 from ..utils.data import (
     get_available_template_names,
     get_available_lora_model_names,
-    get_path_of_available_lora_model)
+    get_path_of_available_lora_model,
+    get_info_of_available_lora_model)
 from ..utils.prompter import Prompter
 from ..utils.callbacks import Iteratorize, Stream
 
@@ -41,7 +42,9 @@ def do_inference(
     prompter = Prompter(prompt_template)
     prompt = prompter.generate_prompt(variables)
 
-    if "/" not in lora_model_name:
+    if not lora_model_name:
+        lora_model_name = "None"
+    if "/" not in lora_model_name and lora_model_name != "None":
         path_of_available_lora_model = get_path_of_available_lora_model(
             lora_model_name)
         if path_of_available_lora_model:
@@ -75,7 +78,7 @@ def do_inference(
         return
 
     model = get_base_model()
-    if lora_model_name:
+    if lora_model_name != "None":
         model = get_model_with_lora(lora_model_name)
     tokenizer = get_tokenizer()
 
@@ -172,7 +175,7 @@ def reload_selections(current_lora_model, current_prompt_template):
         gr.Dropdown.update(choices=available_template_names_with_none, value=current_prompt_template))
 
 
-def handle_prompt_template_change(prompt_template):
+def handle_prompt_template_change(prompt_template, lora_model):
     prompter = Prompter(prompt_template)
     var_names = prompter.get_variable_names()
     human_var_names = [' '.join(word.capitalize()
@@ -182,7 +185,35 @@ def handle_prompt_template_change(prompt_template):
     while len(gr_updates) < 8:
         gr_updates.append(gr.Textbox.update(
             label="Not Used", visible=False))
-    return gr_updates
+
+    model_prompt_template_message_update = gr.Markdown.update("", visible=False)
+    lora_mode_info = get_info_of_available_lora_model(lora_model)
+    if lora_mode_info and isinstance(lora_mode_info, dict):
+        model_prompt_template = lora_mode_info.get("prompt_template")
+        if model_prompt_template and model_prompt_template != prompt_template:
+            model_prompt_template_message_update = gr.Markdown.update(
+                f"Trained with prompt template `{model_prompt_template}`", visible=True)
+
+    return [model_prompt_template_message_update] + gr_updates
+
+
+def handle_lora_model_change(lora_model, prompt_template):
+    lora_mode_info = get_info_of_available_lora_model(lora_model)
+    if not lora_mode_info:
+        return gr.Markdown.update("", visible=False), prompt_template
+
+    if not isinstance(lora_mode_info, dict):
+        return gr.Markdown.update("", visible=False), prompt_template
+
+    model_prompt_template = lora_mode_info.get("prompt_template")
+    if not model_prompt_template:
+        return gr.Markdown.update("", visible=False), prompt_template
+
+    available_template_names = get_available_template_names()
+    if model_prompt_template in available_template_names:
+        return gr.Markdown.update("", visible=False), model_prompt_template
+
+    return gr.Markdown.update(f"Trained with prompt template `{model_prompt_template}`", visible=True), prompt_template
 
 
 def update_prompt_preview(prompt_template,
@@ -200,12 +231,15 @@ def inference_ui():
 
     with gr.Blocks() as inference_ui_blocks:
         with gr.Row():
-            lora_model = gr.Dropdown(
-                label="LoRA Model",
-                elem_id="inference_lora_model",
-                value="tloen/alpaca-lora-7b",
-                allow_custom_value=True,
-            )
+            with gr.Column(elem_id="inference_lora_model_group"):
+                model_prompt_template_message = gr.Markdown(
+                    "", visible=False, elem_id="inference_lora_model_prompt_template_message")
+                lora_model = gr.Dropdown(
+                    label="LoRA Model",
+                    elem_id="inference_lora_model",
+                    value="tloen/alpaca-lora-7b",
+                    allow_custom_value=True,
+                )
             prompt_template = gr.Dropdown(
                 label="Prompt Template",
                 elem_id="inference_prompt_template",
@@ -346,10 +380,20 @@ def inference_ui():
         )
         things_that_might_timeout.append(reload_selections_event)
 
-        prompt_template_change_event = prompt_template.change(
-            fn=handle_prompt_template_change, inputs=[prompt_template], outputs=[variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
+        prompt_template_change_event = prompt_template.change(
+            fn=handle_prompt_template_change,
+            inputs=[prompt_template, lora_model],
+            outputs=[
+                model_prompt_template_message,
+                variable_0, variable_1, variable_2, variable_3, variable_4, variable_5, variable_6, variable_7])
         things_that_might_timeout.append(prompt_template_change_event)
 
+        lora_model_change_event = lora_model.change(
+            fn=handle_lora_model_change,
+            inputs=[lora_model, prompt_template],
+            outputs=[model_prompt_template_message, prompt_template])
+        things_that_might_timeout.append(lora_model_change_event)
+
         generate_event = generate_btn.click(
             fn=do_inference,
             inputs=[
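Taken together, handle_prompt_template_change and handle_lora_model_change implement the commit's auto-selection rule: if the chosen LoRA model records a prompt_template in its info.json, switch the dropdown to it when that template exists locally, otherwise keep the current selection and show a "Trained with prompt template" notice. A minimal sketch of that rule outside Gradio, where `info` and `available` stand in for the values returned by get_info_of_available_lora_model and get_available_template_names:

def pick_template(info, current_template, available):
    # No info.json (or no "prompt_template" key): keep the user's choice.
    if not isinstance(info, dict) or not info.get("prompt_template"):
        return current_template, None
    trained_with = info["prompt_template"]
    # The trained-with template exists locally: switch to it silently.
    if trained_with in available:
        return trained_with, None
    # Otherwise keep the selection but surface what the model was trained with.
    return current_template, f"Trained with prompt template `{trained_with}`"

# Selecting a model trained with "alpaca" auto-selects that template:
assert pick_template({"prompt_template": "alpaca"}, "None", ["None", "alpaca"]) == ("alpaca", None)
# A model with no info.json leaves the current selection alone:
assert pick_template(None, "None", ["None", "alpaca"]) == ("None", None)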
llama_lora/ui/main_page.py
CHANGED
@@ -134,6 +134,41 @@ def main_page_custom_css():
         /* text-transform: uppercase; */
     }
 
+    #inference_lora_model_group {
+        border-radius: var(--block-radius);
+        background: var(--block-background-fill);
+    }
+    #inference_lora_model_group #inference_lora_model {
+        background: transparent;
+    }
+    #inference_lora_model_prompt_template_message:not(.hidden) + #inference_lora_model {
+        padding-bottom: 28px;
+    }
+    #inference_lora_model_group > #inference_lora_model_prompt_template_message {
+        position: absolute;
+        bottom: 8px;
+        left: 20px;
+        z-index: 1;
+        font-size: 12px;
+        opacity: 0.7;
+    }
+    #inference_lora_model_group > #inference_lora_model_prompt_template_message p {
+        font-size: 12px;
+    }
+    #inference_lora_model_prompt_template_message > .wrap {
+        display: none;
+    }
+    #inference_lora_model > .wrap:first-child:not(.hide),
+    #inference_prompt_template > .wrap:first-child:not(.hide) {
+        opacity: 0.5;
+    }
+    #inference_lora_model_group, #inference_lora_model {
+        z-index: 60;
+    }
+    #inference_prompt_template {
+        z-index: 55;
+    }
+
     #inference_prompt_box > *:first-child {
         border-bottom-left-radius: 0;
         border-bottom-right-radius: 0;
@@ -266,12 +301,16 @@ def main_page_custom_css():
     }
 
     @media screen and (min-width: 640px) {
-        #inference_lora_model, #finetune_template {
+        #inference_lora_model, #inference_lora_model_group,
+        #finetune_template {
             border-top-right-radius: 0;
             border-bottom-right-radius: 0;
             border-right: 0;
             margin-right: -16px;
         }
+        #inference_lora_model_group #inference_lora_model {
+            box-shadow: var(--block-shadow);
+        }
 
         #inference_prompt_template {
             border-top-left-radius: 0;
@@ -301,7 +340,7 @@ def main_page_custom_css():
         height: 42px !important;
         min-width: 42px !important;
         width: 42px !important;
-        z-index:
+        z-index: 61;
     }
 }
 
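These selectors hook onto the elem_id values set in inference_ui.py: Gradio renders elem_id as the element's DOM id, so #inference_lora_model_group matches the new gr.Column wrapping the dropdown and message. A stripped-down sketch of that pairing (the CSS rule here is a placeholder, not the stylesheet added in this commit):

import gradio as gr

# elem_id becomes the DOM id, which the page CSS can target directly.
css = "#inference_lora_model_group { border-radius: 8px; }"

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="inference_lora_model_group"):
        gr.Markdown("", visible=False,
                    elem_id="inference_lora_model_prompt_template_message")
        gr.Dropdown(label="LoRA Model", elem_id="inference_lora_model")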
llama_lora/utils/data.py
CHANGED
@@ -52,6 +52,22 @@ def get_path_of_available_lora_model(name):
     return None
 
 
+def get_info_of_available_lora_model(name):
+    try:
+        if "/" in name:
+            return None
+        path_of_available_lora_model = get_path_of_available_lora_model(
+            name)
+        if not path_of_available_lora_model:
+            return None
+
+        with open(os.path.join(path_of_available_lora_model, "info.json"), "r") as json_file:
+            return json.load(json_file)
+
+    except Exception as e:
+        return None
+
+
 def get_dataset_content(name):
     file_name = os.path.join(Global.data_dir, "datasets", name)
     if not os.path.exists(file_name):
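A hypothetical call to the new helper ("some-local-lora" is a made-up model name, and the info.json contents shown are an assumption about what fine-tuning writes there):

from llama_lora.utils.data import get_info_of_available_lora_model

# Returns None for hub-style names containing "/", unknown local models, or an
# unreadable info.json; otherwise the parsed dict, e.g. {"prompt_template": "alpaca"}.
info = get_info_of_available_lora_model("some-local-lora")
if isinstance(info, dict):
    print(info.get("prompt_template"))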