zetavg committed
Commit 300b660
1 Parent(s): 85fb243
llama_lora/lib/finetune.py CHANGED
@@ -70,7 +70,13 @@ def train(
     wandb_tags: List[str] = [],
     wandb_watch: str = "false",  # options: false | gradients | all
     wandb_log_model: str = "true",  # options: false | true
+    status_message_callback: Any = None,
 ):
+    if status_message_callback:
+        cb_result = status_message_callback("Preparing training...")
+        if cb_result:
+            return
+
     if lora_modules_to_save is not None and len(lora_modules_to_save) <= 0:
         lora_modules_to_save = None
 
@@ -171,6 +177,16 @@ def train(
     if ddp:
         device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
 
+    if status_message_callback:
+        if isinstance(base_model, str):
+            cb_result = status_message_callback(f"Preparing model '{base_model}' for training...")
+            if cb_result:
+                return
+        else:
+            cb_result = status_message_callback("Preparing model for training...")
+            if cb_result:
+                return
+
     model = base_model
     if isinstance(model, str):
         model_name = model
@@ -216,51 +232,16 @@ def train(
     # )
     tokenizer.padding_side = "left"  # Allow batched inference
 
-    def tokenize(prompt, add_eos_token=True):
-        # there's probably a way to do this with the tokenizer settings
-        # but again, gotta move fast
-        result = tokenizer(
-            prompt,
-            truncation=True,
-            max_length=cutoff_len,
-            padding=False,
-            return_tensors=None,
-        )
-        if (
-            result["input_ids"][-1] != tokenizer.eos_token_id
-            and len(result["input_ids"]) < cutoff_len
-            and add_eos_token
-        ):
-            result["input_ids"].append(tokenizer.eos_token_id)
-            result["attention_mask"].append(1)
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
-    def generate_and_tokenize_prompt(data_point):
-        full_prompt = data_point["prompt"] + data_point["completion"]
-        tokenized_full_prompt = tokenize(full_prompt)
-        if not train_on_inputs:
-            user_prompt = data_point["prompt"]
-            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
-            user_prompt_len = len(tokenized_user_prompt["input_ids"])
-
-            tokenized_full_prompt["labels"] = [
-                -100
-            ] * user_prompt_len + tokenized_full_prompt["labels"][
-                user_prompt_len:
-            ]  # could be sped up, probably
-        return tokenized_full_prompt
-
-    # will fail anyway.
     try:
         model = prepare_model_for_int8_training(model)
     except Exception as e:
         print(
             f"Got error while running prepare_model_for_int8_training(model), maybe the model has already be prepared. Original error: {e}.")
 
-    # model = prepare_model_for_int8_training(model)
+    if status_message_callback:
+        cb_result = status_message_callback("Preparing PEFT model for training...")
+        if cb_result:
+            return
 
     lora_config_args = {
         'r': lora_r,
@@ -279,12 +260,6 @@ def train(
     if bf16:
         model = model.to(torch.bfloat16)
 
-    # If train_data is a list, convert it to datasets.Dataset
-    if isinstance(train_data, list):
-        with open(os.path.join(output_dir, "train_data_samples.json"), 'w') as file:
-            json.dump(list(train_data[:100]), file, indent=2)
-        train_data = Dataset.from_list(train_data)
-
     if resume_from_checkpoint:
         # Check the available weights and load them
         checkpoint_name = os.path.join(
@@ -320,6 +295,54 @@ def train(
         wandb.config.update({"model": {"all_params": all_params, "trainable_params": trainable_params,
                                        "trainable%": 100 * trainable_params / all_params}})
 
+    if status_message_callback:
+        cb_result = status_message_callback("Preparing train data...")
+        if cb_result:
+            return
+
+    def tokenize(prompt, add_eos_token=True):
+        # there's probably a way to do this with the tokenizer settings
+        # but again, gotta move fast
+        result = tokenizer(
+            prompt,
+            truncation=True,
+            max_length=cutoff_len,
+            padding=False,
+            return_tensors=None,
+        )
+        if (
+            result["input_ids"][-1] != tokenizer.eos_token_id
+            and len(result["input_ids"]) < cutoff_len
+            and add_eos_token
+        ):
+            result["input_ids"].append(tokenizer.eos_token_id)
+            result["attention_mask"].append(1)
+
+        result["labels"] = result["input_ids"].copy()
+
+        return result
+
+    def generate_and_tokenize_prompt(data_point):
+        full_prompt = data_point["prompt"] + data_point["completion"]
+        tokenized_full_prompt = tokenize(full_prompt)
+        if not train_on_inputs:
+            user_prompt = data_point["prompt"]
+            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
+            user_prompt_len = len(tokenized_user_prompt["input_ids"])
+
+            tokenized_full_prompt["labels"] = [
+                -100
+            ] * user_prompt_len + tokenized_full_prompt["labels"][
+                user_prompt_len:
+            ]  # could be sped up, probably
+        return tokenized_full_prompt
+
+    # If train_data is a list, convert it to datasets.Dataset
+    if isinstance(train_data, list):
+        with open(os.path.join(output_dir, "train_data_samples.json"), 'w') as file:
+            json.dump(list(train_data[:100]), file, indent=2)
+        train_data = Dataset.from_list(train_data)
+
     if val_set_size > 0:
         train_val = train_data.train_test_split(
             test_size=val_set_size, shuffle=True, seed=42
@@ -339,6 +362,11 @@ def train(
         model.is_parallelizable = True
         model.model_parallel = True
 
+    if status_message_callback:
+        cb_result = status_message_callback("Train starting...")
+        if cb_result:
+            return
+
     # https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
     training_args = {
         'output_dir': output_dir,
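The new `status_message_callback` parameter follows a small contract: `train()` calls it with a human-readable phase message before each expensive step, and any truthy return value makes `train()` return early. A minimal sketch of a conforming callback (this example callback is illustrative and not part of the commit):

def log_status(message):
    # Hypothetical callback: just echo progress; never request an abort.
    print(f"[finetune] {message}")
    return False  # a truthy return would make train() return early

# Usage (other arguments omitted):
# train(..., status_message_callback=log_status)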
 
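When `train_on_inputs` is false, the relocated `generate_and_tokenize_prompt` masks the prompt portion of `labels` with -100, the label value ignored by the cross-entropy loss transformers uses, so only completion tokens contribute to the loss. A toy illustration with made-up token ids:

# Hypothetical ids: 3 prompt tokens, 2 completion tokens, then EOS (id 2).
input_ids = [318, 417, 992, 605, 714, 2]
user_prompt_len = 3
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
assert labels == [-100, -100, -100, 605, 714, 2]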
llama_lora/ui/finetune/finetune_ui.py CHANGED
@@ -309,6 +309,7 @@ def handle_lora_modules_to_save_add(choices, new_module, selected_modules):
 
 def do_abort_training():
     Global.should_stop_training = True
+    Global.training_status_text = "Aborting..."
 
 
 def finetune_ui():
 
llama_lora/ui/finetune/training.py CHANGED
@@ -22,6 +22,13 @@ from ..trainer_callback import (
 from .data_processing import get_data_from_input
 
 
+def status_message_callback(message):
+    if Global.should_stop_training:
+        return True
+
+    Global.training_status_text = message
+
+
 def do_train(
     # Dataset
     template,
@@ -254,6 +261,7 @@ def do_train(
         train_output = Global.finetune_train_fn(
            train_data=train_data,
            callbacks=training_callbacks,
+           status_message_callback=status_message_callback,
            **finetune_args,
        )
 
 
llama_lora/ui/trainer_callback.py CHANGED
@@ -57,6 +57,9 @@ def update_training_states(
     Global.training_log_history = log_history
     Global.training_eta = Global.training_eta_predictor.predict_eta(current_step, total_steps)
 
+    if Global.should_stop_training:
+        return
+
     last_history = None
     last_loss = None
     if len(Global.training_log_history) > 0:
 
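Together, the four files implement a cooperative abort handshake: `do_abort_training()` sets `Global.should_stop_training` and shows "Aborting...", the next `status_message_callback` call returns `True` so `train()` exits between phases, and the early return in `update_training_states` keeps trainer callbacks from overwriting the status text. A self-contained sketch of that flow (the `Global` stub and the phase-loop `train()` below are simplified stand-ins, not the real implementations):

class Global:
    # Stand-in for the app's shared state; only the two fields the
    # handshake touches.
    should_stop_training = False
    training_status_text = ""


def do_abort_training():  # as in finetune_ui.py
    Global.should_stop_training = True
    Global.training_status_text = "Aborting..."


def status_message_callback(message):  # as in training.py
    if Global.should_stop_training:
        return True  # asks train() to stop
    Global.training_status_text = message


def train(status_message_callback=None):  # finetune.py, simplified
    for phase in ("Preparing training...",
                  "Preparing model for training...",
                  "Preparing PEFT model for training...",
                  "Preparing train data...",
                  "Train starting..."):
        if status_message_callback and status_message_callback(phase):
            return  # abort requested
        # ... real work for this phase would happen here ...


do_abort_training()
train(status_message_callback=status_message_callback)
assert Global.training_status_text == "Aborting..."  # status survives the abort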