Spaces:
Runtime error
Runtime error
Merge branch 'main' of github.com:zetavg/llama-lora
Browse files
llama_lora/ui/inference_ui.py
CHANGED
@@ -104,11 +104,12 @@ def do_inference(
|
|
104 |
model = get_model(base_model_name, lora_model_name)
|
105 |
|
106 |
generation_config = GenerationConfig(
|
107 |
-
temperature=temperature,
|
108 |
top_p=top_p,
|
109 |
top_k=top_k,
|
110 |
repetition_penalty=repetition_penalty,
|
111 |
num_beams=num_beams,
|
|
|
112 |
)
|
113 |
|
114 |
def ui_generation_stopping_criteria(input_ids, score, **kwargs):
|
@@ -325,7 +326,7 @@ def inference_ui():
|
|
325 |
# with gr.Column():
|
326 |
with gr.Accordion("Options", open=True, elem_id="inference_options_accordion"):
|
327 |
temperature = gr.Slider(
|
328 |
-
minimum=0, maximum=
|
329 |
label="Temperature",
|
330 |
elem_id="inference_temperature"
|
331 |
)
|
@@ -344,7 +345,7 @@ def inference_ui():
|
|
344 |
)
|
345 |
|
346 |
num_beams = gr.Slider(
|
347 |
-
minimum=1, maximum=5, value=
|
348 |
label="Beams",
|
349 |
elem_id="inference_beams"
|
350 |
)
|
|
|
104 |
model = get_model(base_model_name, lora_model_name)
|
105 |
|
106 |
generation_config = GenerationConfig(
|
107 |
+
temperature=float(temperature), # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
|
108 |
top_p=top_p,
|
109 |
top_k=top_k,
|
110 |
repetition_penalty=repetition_penalty,
|
111 |
num_beams=num_beams,
|
112 |
+
do_sample=temperature > 0, # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
|
113 |
)
|
114 |
|
115 |
def ui_generation_stopping_criteria(input_ids, score, **kwargs):
|
|
|
326 |
# with gr.Column():
|
327 |
with gr.Accordion("Options", open=True, elem_id="inference_options_accordion"):
|
328 |
temperature = gr.Slider(
|
329 |
+
minimum=0, maximum=2, value=0.1, step=0.01,
|
330 |
label="Temperature",
|
331 |
elem_id="inference_temperature"
|
332 |
)
|
|
|
345 |
)
|
346 |
|
347 |
num_beams = gr.Slider(
|
348 |
+
minimum=1, maximum=5, value=1, step=1,
|
349 |
label="Beams",
|
350 |
elem_id="inference_beams"
|
351 |
)
|