zetavg committed on
Commit
350cbe9
2 Parent(s): 9d46857 760ed6d

Merge branch 'main' of github.com:zetavg/llama-lora

Browse files
Files changed (1) hide show
  1. llama_lora/ui/inference_ui.py +4 -3
llama_lora/ui/inference_ui.py CHANGED
@@ -104,11 +104,12 @@ def do_inference(
104
  model = get_model(base_model_name, lora_model_name)
105
 
106
  generation_config = GenerationConfig(
107
- temperature=temperature,
108
  top_p=top_p,
109
  top_k=top_k,
110
  repetition_penalty=repetition_penalty,
111
  num_beams=num_beams,
 
112
  )
113
 
114
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
@@ -325,7 +326,7 @@ def inference_ui():
325
  # with gr.Column():
326
  with gr.Accordion("Options", open=True, elem_id="inference_options_accordion"):
327
  temperature = gr.Slider(
328
- minimum=0, maximum=1, value=0.1, step=0.01,
329
  label="Temperature",
330
  elem_id="inference_temperature"
331
  )
@@ -344,7 +345,7 @@ def inference_ui():
344
  )
345
 
346
  num_beams = gr.Slider(
347
- minimum=1, maximum=5, value=2, step=1,
348
  label="Beams",
349
  elem_id="inference_beams"
350
  )
 
104
  model = get_model(base_model_name, lora_model_name)
105
 
106
  generation_config = GenerationConfig(
107
+ temperature=float(temperature), # to avoid ValueError('`temperature` has to be a strictly positive float, but is 2')
108
  top_p=top_p,
109
  top_k=top_k,
110
  repetition_penalty=repetition_penalty,
111
  num_beams=num_beams,
112
+ do_sample=temperature > 0, # https://github.com/huggingface/transformers/issues/22405#issuecomment-1485527953
113
  )
114
 
115
  def ui_generation_stopping_criteria(input_ids, score, **kwargs):
 
326
  # with gr.Column():
327
  with gr.Accordion("Options", open=True, elem_id="inference_options_accordion"):
328
  temperature = gr.Slider(
329
+ minimum=0, maximum=2, value=0.1, step=0.01,
330
  label="Temperature",
331
  elem_id="inference_temperature"
332
  )
 
345
  )
346
 
347
  num_beams = gr.Slider(
348
+ minimum=1, maximum=5, value=1, step=1,
349
  label="Beams",
350
  elem_id="inference_beams"
351
  )