zetavg committed
Commit 3daa16f
Parent: 1e27707

show loss/epoch chart on finetune ui
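
This commit adds a live loss-over-epoch chart to the finetune tab: training.py gains a render_loss_plot() that turns Global.training_log_history into an Altair line chart, finetune_ui.py adds a hidden gr.Plot that is refreshed every 10 seconds, style.css restyles the new plot container, and requirements.txt picks up the altair and pandas dependencies (gradio, loralib, and sentencepiece are also re-sorted alphabetically).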
llama_lora/ui/finetune/finetune_ui.py CHANGED
@@ -28,7 +28,8 @@ from .previewing import (
 )
 from .training import (
     do_train,
-    render_training_status
+    render_training_status,
+    render_loss_plot
 )
 
 register_css_style('finetune', relative_read_file(__file__, "style.css"))
@@ -773,10 +774,15 @@ def finetune_ui():
         )
 
         train_status = gr.HTML(
-            "Training results will be shown here.",
+            "",
             label="Train Output",
            elem_id="finetune_training_status")
 
+        with gr.Column(visible=False, elem_id="finetune_loss_plot_container") as loss_plot_container:
+            loss_plot = gr.Plot(
+                visible=False, show_label=False,
+                elem_id="finetune_loss_plot")
+
         training_indicator = gr.HTML(
             "training_indicator", visible=False, elem_id="finetune_training_indicator")
 
@@ -787,7 +793,8 @@ def finetune_ui():
                 continue_from_model,
                 continue_from_checkpoint,
             ]),
-            outputs=[train_status, training_indicator]
+            outputs=[train_status, training_indicator,
+                     loss_plot_container, loss_plot]
         )
 
         # controlled by JS, shows the confirm_abort_button
@@ -803,6 +810,12 @@ def finetune_ui():
         outputs=[train_status, training_indicator],
         every=0.2
     )
+    loss_plot_updates = finetune_ui_blocks.load(
+        fn=render_loss_plot,
+        inputs=None,
+        outputs=[loss_plot_container, loss_plot],
+        every=10
+    )
     finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
 
     # things_that_might_timeout.append(training_status_updates)
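
Note: the loss chart is refreshed with the same polling pattern already used for the training status: Blocks.load(fn=..., every=N) re-runs the callback every N seconds while the page is open, and the callback toggles component visibility via the component's update() so the plot only appears once there is data. A minimal, self-contained sketch of that pattern (assuming Gradio 3.x, where every= requires the queue to be enabled; all names below are illustrative, not part of this app):

    import time
    import gradio as gr

    started = time.time()

    def poll_status():
        # Hide the component until there is something to show, mirroring
        # how render_loss_plot keeps the empty plot hidden.
        elapsed = time.time() - started
        if elapsed < 5:
            return gr.HTML.update(visible=False)
        return gr.HTML.update(value=f"{elapsed:.0f}s elapsed", visible=True)

    with gr.Blocks() as demo:
        status = gr.HTML(visible=False)
        # Re-run poll_status every second for as long as the page stays open.
        demo.load(fn=poll_status, inputs=None, outputs=[status], every=1)

    demo.queue().launch()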
llama_lora/ui/finetune/style.css CHANGED
@@ -255,7 +255,9 @@
   display: none;
 }
 
-#finetune_training_status > .wrap {
+#finetune_training_status > .wrap,
+#finetune_loss_plot_container > .wrap,
+#finetune_loss_plot > .wrap {
   border: 0;
   background: transparent;
   pointer-events: none;
@@ -264,6 +266,17 @@
   left: 0;
   right: 0;
 }
+#finetune_training_status > .wrap:not(.generating)::after {
+  content: "Refresh the page if this takes too long.";
+  position: absolute;
+  top: 0;
+  left: 0;
+  right: 0;
+  bottom: 0;
+  padding-top: 64px;
+  opacity: 0.5;
+  text-align: center;
+}
 #finetune_training_status > .wrap .meta-text-center {
   transform: none !important;
 }
@@ -383,5 +396,18 @@
   /* background: var(--error-background-fill) !important; */
   border: 1px solid var(--error-border-color) !important;
 }
+#finetune_loss_plot {
+  padding: var(--block-padding);
+}
+#finetune_loss_plot .altair {
+  overflow: auto !important;
+}
+#finetune_loss_plot .altair > * {
+  margin: auto !important;
+}
+#finetune_loss_plot .vega-embed summary {
+  border: 0;
+  box-shadow: none;
+}
 
 #finetune_training_indicator { display: none; }
llama_lora/ui/finetune/training.py CHANGED
@@ -1,11 +1,14 @@
 import os
 import json
 import time
+import math
 import datetime
 import pytz
 import socket
 import threading
 import traceback
+import altair as alt
+import pandas as pd
 import gradio as gr
 
 from huggingface_hub import try_to_load_from_cache, snapshot_download
@@ -71,7 +74,7 @@ def do_train(
     progress=gr.Progress(track_tqdm=False),
 ):
     if Global.is_training or Global.is_train_starting:
-        return render_training_status()
+        return render_training_status() + render_loss_plot()
 
     reset_training_status()
     Global.is_train_starting = True
@@ -206,6 +209,9 @@ def do_train(
         print(message)
 
         total_steps = 300
+        log_history = []
+        initial_loss = 2
+        loss_decay_rate = 0.8
         for i in range(300):
             if (Global.should_stop_training):
                 break
@@ -213,11 +219,14 @@ def do_train(
             current_step = i + 1
             total_epochs = 3
             current_epoch = i / 100
-            log_history = []
 
             if (i > 20):
-                loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
-                log_history = [{'loss': loss}]
+                loss = initial_loss * math.exp(-loss_decay_rate * current_epoch)
+                log_history.append({
+                    'loss': loss,
+                    'learning_rate': 0.0001,
+                    'epoch': current_epoch
+                })
 
             update_training_states(
                 total_steps=total_steps,
@@ -295,7 +304,7 @@ def do_train(
     finally:
         Global.is_train_starting = False
 
-    return render_training_status()
+    return render_training_status() + render_loss_plot()
 
 
 def render_training_status():
@@ -411,6 +420,51 @@ def render_training_status():
     return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
 
 
+def render_loss_plot():
+    if len(Global.training_log_history) <= 2:
+        return (gr.Column.update(visible=False), gr.Plot.update(visible=False))
+
+    training_log_history = Global.training_log_history
+
+    loss_data = [
+        {
+            'type': 'train_loss' if 'loss' in item else 'eval_loss',
+            'loss': item.get('loss') or item.get('eval_loss'),
+            'epoch': item.get('epoch')
+        } for item in training_log_history
+        if ('loss' in item or 'eval_loss' in item)
+        and 'epoch' in item
+    ]
+
+    source = pd.DataFrame(loss_data)
+
+    highlight = alt.selection(
+        type='single',  # type: ignore
+        on='mouseover', fields=['type'], nearest=True
+    )
+
+    base = alt.Chart(source).encode(  # type: ignore
+        x='epoch:Q',
+        y='loss:Q',
+        color='type:N',
+        tooltip=['type:N', 'loss:Q', 'epoch:Q']
+    )
+
+    points = base.mark_circle().encode(
+        opacity=alt.value(0)
+    ).add_selection(
+        highlight
+    ).properties(
+        width=640
+    )
+
+    lines = base.mark_line().encode(
+        size=alt.condition(~highlight, alt.value(1), alt.value(3))
+    )
+
+    return (gr.Column.update(visible=True), gr.Plot.update(points + lines, visible=True))
+
+
 def format_time(seconds):
     hours, remainder = divmod(seconds, 3600)
     minutes, seconds = divmod(remainder, 60)
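
Note: render_loss_plot builds a layered Altair chart (invisible hover points plus lines whose stroke width thickens for the hovered series) from the accumulated log history. A standalone sketch of the same construction, runnable outside the app against synthetic data shaped like the simulated loss above (assuming Altair 4.x and pandas; alt.selection(type='single') and .add_selection are the pre-Altair-5 spellings used in the commit):

    import math
    import altair as alt
    import pandas as pd

    # Synthetic history matching the simulation: loss = 2 * exp(-0.8 * epoch).
    source = pd.DataFrame([
        {'type': 'train_loss',
         'loss': 2 * math.exp(-0.8 * (i / 100)),
         'epoch': i / 100}
        for i in range(21, 300)
    ])

    highlight = alt.selection(
        type='single', on='mouseover', fields=['type'], nearest=True)

    base = alt.Chart(source).encode(
        x='epoch:Q', y='loss:Q', color='type:N',
        tooltip=['type:N', 'loss:Q', 'epoch:Q'])

    # Transparent points capture hover events; lines thicken when hovered.
    points = base.mark_circle().encode(
        opacity=alt.value(0)).add_selection(highlight).properties(width=640)
    lines = base.mark_line().encode(
        size=alt.condition(~highlight, alt.value(1), alt.value(3)))

    (points + lines).save('loss_plot.html')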
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 accelerate
+altair
 appdirs
 bitsandbytes
 black
@@ -7,10 +8,11 @@ datasets
 fire
 git+https://github.com/huggingface/peft.git
 git+https://github.com/huggingface/transformers.git
+gradio
 huggingface_hub
+loralib
 numba
 nvidia-ml-py3
-gradio
-loralib
-sentencepiece
+pandas
 random-word
+sentencepiece