zetavg commited on
Commit
d2eef14
1 Parent(s): 8177d08

finetune: load template and dataset from model

Browse files
Files changed (1) hide show
  1. llama_lora/ui/finetune_ui.py +18 -2
llama_lora/ui/finetune_ui.py CHANGED
@@ -621,6 +621,7 @@ def handle_continue_from_model_change(model_name):
621
 
622
  def handle_load_params_from_model(
623
  model_name,
 
624
  max_seq_length,
625
  evaluate_data_count,
626
  micro_batch_size,
@@ -654,6 +655,20 @@ def handle_load_params_from_model(
654
  lora_model_directory_path = os.path.join(
655
  lora_models_directory_path, model_name)
656
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
657
  data = {}
658
  possible_files = ["finetune_params.json", "finetune_args.json"]
659
  for file in possible_files:
@@ -747,6 +762,7 @@ def handle_load_params_from_model(
747
 
748
  return (
749
  gr.Markdown.update(value=message, visible=has_message),
 
750
  max_seq_length,
751
  evaluate_data_count,
752
  micro_batch_size,
@@ -1231,9 +1247,9 @@ def finetune_ui():
1231
  things_that_might_timeout.append(
1232
  load_params_from_model_btn.click(
1233
  fn=handle_load_params_from_model,
1234
- inputs=[continue_from_model] + finetune_args +
1235
  [lora_target_module_choices, lora_modules_to_save_choices],
1236
- outputs=[load_params_from_model_message] + finetune_args +
1237
  [lora_target_module_choices, lora_modules_to_save_choices]
1238
  )
1239
  )
 
621
 
622
  def handle_load_params_from_model(
623
  model_name,
624
+ template, load_dataset_from, dataset_from_data_dir,
625
  max_seq_length,
626
  evaluate_data_count,
627
  micro_batch_size,
 
655
  lora_model_directory_path = os.path.join(
656
  lora_models_directory_path, model_name)
657
 
658
+ try:
659
+ with open(os.path.join(lora_model_directory_path, "info.json"), "r") as f:
660
+ info = json.load(f)
661
+ if isinstance(info, dict):
662
+ model_prompt_template = info.get("prompt_template")
663
+ if model_prompt_template:
664
+ template = model_prompt_template
665
+ model_dataset_name = info.get("dataset_name")
666
+ if model_dataset_name and isinstance(model_dataset_name, str) and not model_dataset_name.startswith("N/A"):
667
+ load_dataset_from = "Data Dir"
668
+ dataset_from_data_dir = model_dataset_name
669
+ except FileNotFoundError:
670
+ pass
671
+
672
  data = {}
673
  possible_files = ["finetune_params.json", "finetune_args.json"]
674
  for file in possible_files:
 
762
 
763
  return (
764
  gr.Markdown.update(value=message, visible=has_message),
765
+ template, load_dataset_from, dataset_from_data_dir,
766
  max_seq_length,
767
  evaluate_data_count,
768
  micro_batch_size,
 
1247
  things_that_might_timeout.append(
1248
  load_params_from_model_btn.click(
1249
  fn=handle_load_params_from_model,
1250
+ inputs=[continue_from_model] + [template, load_dataset_from, dataset_from_data_dir] + finetune_args +
1251
  [lora_target_module_choices, lora_modules_to_save_choices],
1252
+ outputs=[load_params_from_model_message] + [template, load_dataset_from, dataset_from_data_dir] + finetune_args +
1253
  [lora_target_module_choices, lora_modules_to_save_choices]
1254
  )
1255
  )