zetavg commited on
Commit
341c612
1 Parent(s): 9ee06c7
Files changed (1) hide show
  1. llama_lora/ui/finetune_ui.py +7 -10
llama_lora/ui/finetune_ui.py CHANGED
@@ -3,6 +3,7 @@ import json
3
  import time
4
  from datetime import datetime
5
  import gradio as gr
 
6
  from random_word import RandomWords
7
 
8
  from transformers import TrainerCallback
@@ -334,25 +335,21 @@ Train data (first 10):
334
  return message
335
 
336
  class UiTrainerCallback(TrainerCallback):
337
- def on_epoch_begin(self, args, state, control, **kwargs):
338
  if Global.should_stop_training:
339
  control.should_training_stop = True
340
  total_steps = (
341
  state.max_steps if state.max_steps is not None else state.num_train_epochs * state.steps_per_epoch)
342
  progress(
343
  (state.global_step, total_steps),
344
- desc=f"Training... (Epoch {state.epoch}/{epochs}, Step {state.global_step}/{total_steps})"
345
  )
346
 
 
 
 
347
  def on_step_end(self, args, state, control, **kwargs):
348
- if Global.should_stop_training:
349
- control.should_training_stop = True
350
- total_steps = (
351
- state.max_steps if state.max_steps is not None else state.num_train_epochs * state.steps_per_epoch)
352
- progress(
353
- (state.global_step, total_steps),
354
- desc=f"Training... (Epoch {state.epoch}/{epochs}, Step {state.global_step}/{total_steps})"
355
- )
356
 
357
  training_callbacks = [UiTrainerCallback]
358
 
 
3
  import time
4
  from datetime import datetime
5
  import gradio as gr
6
+ import math
7
  from random_word import RandomWords
8
 
9
  from transformers import TrainerCallback
 
335
  return message
336
 
337
  class UiTrainerCallback(TrainerCallback):
338
+ def _on_progress(self, args, state, control):
339
  if Global.should_stop_training:
340
  control.should_training_stop = True
341
  total_steps = (
342
  state.max_steps if state.max_steps is not None else state.num_train_epochs * state.steps_per_epoch)
343
  progress(
344
  (state.global_step, total_steps),
345
+ desc=f"Training... (Epoch {math.ceil(state.epoch)}/{epochs}, Step {state.global_step}/{total_steps})"
346
  )
347
 
348
+ def on_epoch_begin(self, args, state, control, **kwargs):
349
+ self._on_progress(args, state, control)
350
+
351
  def on_step_end(self, args, state, control, **kwargs):
352
+ self._on_progress(args, state, control)
 
 
 
 
 
 
 
353
 
354
  training_callbacks = [UiTrainerCallback]
355