zetavg committed
Commit 600770f
Parent: 82e3afe

support more options on fine-tuning

llama_lora/ui/finetune_ui.py CHANGED
@@ -258,13 +258,16 @@ def do_train(
     dataset_plain_text_data_separator,
     # Training Options
     max_seq_length,
+    evaluate_data_percentage,
     micro_batch_size,
     gradient_accumulation_steps,
     epochs,
     learning_rate,
+    train_on_inputs,
     lora_r,
     lora_alpha,
     lora_dropout,
+    lora_target_modules,
     model_name,
     progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
 ):
@@ -324,6 +327,7 @@ def do_train(
     data = process_json_dataset(data)

     data_count = len(data)
+    evaluate_data_count = math.ceil(data_count * evaluate_data_percentage)

     train_data = [
         {
@@ -361,13 +365,16 @@ def do_train(

 Train options: {json.dumps({
     'max_seq_length': max_seq_length,
+    'val_set_size': evaluate_data_count,
     'micro_batch_size': micro_batch_size,
     'gradient_accumulation_steps': gradient_accumulation_steps,
     'epochs': epochs,
     'learning_rate': learning_rate,
+    'train_on_inputs': train_on_inputs,
     'lora_r': lora_r,
     'lora_alpha': lora_alpha,
     'lora_dropout': lora_dropout,
+    'lora_target_modules': lora_target_modules,
     'model_name': model_name,
 }, indent=2)}

@@ -436,7 +443,22 @@ Train data (first 10):
         'prompt_template': template,
         'dataset_name': dataset_name,
         'dataset_rows': len(train_data),
-        'timestamp': time.time()
+        'timestamp': time.time(),
+
+        'max_seq_length': max_seq_length,
+        'train_on_inputs': train_on_inputs,
+
+        'micro_batch_size': micro_batch_size,
+        'gradient_accumulation_steps': gradient_accumulation_steps,
+        'epochs': epochs,
+        'learning_rate': learning_rate,
+
+        'evaluate_data_percentage': evaluate_data_percentage,
+
+        'lora_r': lora_r,
+        'lora_alpha': lora_alpha,
+        'lora_dropout': lora_dropout,
+        'lora_target_modules': lora_target_modules,
     }
     json.dump(info, info_json_file, indent=2)

@@ -454,12 +476,12 @@ Train data (first 10):
         epochs, # num_epochs
         learning_rate, # learning_rate
         max_seq_length, # cutoff_len
-        0, # val_set_size
+        evaluate_data_count, # val_set_size
         lora_r, # lora_r
         lora_alpha, # lora_alpha
         lora_dropout, # lora_dropout
-        ["q_proj", "v_proj"], # lora_target_modules
-        True, # train_on_inputs
+        lora_target_modules, # lora_target_modules
+        train_on_inputs, # train_on_inputs
         False, # group_by_length
         None, # resume_from_checkpoint
         training_callbacks # callbacks
@@ -623,11 +645,20 @@ def finetune_ui():
                 )
             )

-            max_seq_length = gr.Slider(
-                minimum=1, maximum=4096, value=512,
-                label="Max Sequence Length",
-                info="The maximum length of each sample text sequence. Sequences longer than this will be truncated."
-            )
+            with gr.Row():
+                max_seq_length = gr.Slider(
+                    minimum=1, maximum=4096, value=512,
+                    label="Max Sequence Length",
+                    info="The maximum length of each sample text sequence. Sequences longer than this will be truncated.",
+                    elem_id="finetune_max_seq_length"
+                )
+
+                train_on_inputs = gr.Checkbox(
+                    label="Train on Inputs",
+                    value=True,
+                    info="If not enabled, inputs will be masked out in loss.",
+                    elem_id="finetune_train_on_inputs"
+                )

             with gr.Row():
                 micro_batch_size_default_value = 1
@@ -663,6 +694,12 @@ def finetune_ui():
                         info="The initial learning rate for the optimizer. A higher learning rate may speed up convergence but also cause instability or divergence. A lower learning rate may require more steps to reach optimal performance but also avoid overshooting or oscillating around local minima."
                     )

+                    evaluate_data_percentage = gr.Slider(
+                        minimum=0, maximum=0.5, step=0.001, value=0.03,
+                        label="Evaluation Data Percentage",
+                        info="The percentage of data to be used for evaluation. This percentage of data will not be used for training and will be used to assess the performance of the model during the process."
+                    )
+
                 with gr.Column():
                     lora_r = gr.Slider(
                         minimum=1, maximum=16, step=1, value=8,
@@ -682,6 +719,12 @@ def finetune_ui():
                         info="The dropout probability for LoRA, which controls the fraction of LoRA parameters that are set to zero during training. A larger lora_dropout increases the regularization effect of LoRA but also increases the risk of underfitting."
                     )

+                    lora_target_modules = gr.CheckboxGroup(
+                        label="LoRA Target Modules",
+                        choices=["q_proj", "k_proj", "v_proj", "o_proj"],
+                        value=["q_proj", "v_proj"],
+                    )
+
                 with gr.Column():
                     model_name = gr.Textbox(
                         lines=1, label="LoRA Model Name", value=random_name,
@@ -712,13 +755,16 @@ def finetune_ui():
             fn=do_train,
             inputs=(dataset_inputs + [
                 max_seq_length,
+                evaluate_data_percentage,
                 micro_batch_size,
                 gradient_accumulation_steps,
                 epochs,
                 learning_rate,
+                train_on_inputs,
                 lora_r,
                 lora_alpha,
                 lora_dropout,
+                lora_target_modules,
                 model_name
             ]),
             outputs=train_output
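
Note on the new "Train on Inputs" option above: when it is disabled, prompt tokens are typically excluded from the loss by assigning them the label -100, the ignore index used by Hugging Face causal-LM training. The sketch below illustrates that masking only; tokenize_example and its helper logic are hypothetical names, not code from this repository.

# Illustrative only: typical loss masking when "Train on Inputs" is disabled.
# `tokenizer` is any Hugging Face tokenizer; the function name and arguments
# here are assumptions, not taken from llama_lora.
IGNORE_INDEX = -100  # labels with this value are skipped by the causal-LM loss


def tokenize_example(tokenizer, prompt, full_text, cutoff_len, train_on_inputs):
    tokenized = tokenizer(full_text, truncation=True, max_length=cutoff_len)
    labels = list(tokenized["input_ids"])
    if not train_on_inputs:
        # Mask the prompt portion so only the response tokens contribute to the loss.
        prompt_len = len(
            tokenizer(prompt, truncation=True, max_length=cutoff_len)["input_ids"]
        )
        prompt_len = min(prompt_len, len(labels))
        labels[:prompt_len] = [IGNORE_INDEX] * prompt_len
    tokenized["labels"] = labels
    return tokenized
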
llama_lora/ui/main_page.py CHANGED
@@ -428,6 +428,9 @@ def main_page_custom_css():
         white-space: pre-wrap;
     }

+    #finetune_max_seq_length {
+        flex: 2;
+    }

     @media screen and (max-width: 392px) {
         #inference_lora_model, #finetune_template {
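
Note on the CSS rule above: #finetune_max_seq_length matches the elem_id given to the Max Sequence Length slider in finetune_ui.py, letting that slider take more horizontal space than its row siblings. A minimal standalone sketch of how a Gradio elem_id pairs with custom CSS (standard gr.Blocks usage, not code from this repository):

import gradio as gr

# Illustrative only: a component's elem_id can be targeted by CSS passed to
# gr.Blocks(css=...).
custom_css = """
#finetune_max_seq_length {
    flex: 2;  /* take twice the row space of sibling components */
}
"""

with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        max_seq_length = gr.Slider(
            minimum=1, maximum=4096, value=512,
            label="Max Sequence Length",
            elem_id="finetune_max_seq_length",  # matched by the CSS rule above
        )
        train_on_inputs = gr.Checkbox(label="Train on Inputs", value=True)

# demo.launch()  # uncomment to run the sketch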