Damian Stewart committed on
Commit
6067469
1 Parent(s): c8aa68b

save every N steps and loss logging

Browse files
Files changed (4) hide show
  1. StableDiffuser.py +3 -3
  2. app.py +40 -18
  3. isolate_rng.py +73 -0
  4. train.py +22 -2
StableDiffuser.py CHANGED
@@ -95,8 +95,8 @@ class StableDiffuser(torch.nn.Module):
95
  def set_scheduler_timesteps(self, n_steps):
96
  self.scheduler.set_timesteps(n_steps, device=self.unet.device)
97
 
98
- def get_initial_latents(self, n_imgs, width, height, n_prompts, generator=None):
99
- noise = self.get_noise(n_imgs, width, height, generator=generator).repeat(n_prompts, 1, 1, 1)
100
  latents = noise * self.scheduler.init_noise_sigma
101
  return latents
102
 
@@ -199,7 +199,7 @@ class StableDiffuser(torch.nn.Module):
199
  prompts = [prompts]
200
 
201
  self.set_scheduler_timesteps(n_steps)
202
- latents = self.get_initial_latents(n_imgs, width, height, len(prompts), generator=generator)
203
  text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
204
  end_iteration = end_iteration or n_steps
205
  latents_steps, trace_steps = self.diffusion(
 
95
  def set_scheduler_timesteps(self, n_steps):
96
  self.scheduler.set_timesteps(n_steps, device=self.unet.device)
97
 
98
+ def get_initial_latents(self, n_imgs, height, width, n_prompts, generator=None):
99
+ noise = self.get_noise(n_imgs, height, width, generator=generator).repeat(n_prompts, 1, 1, 1)
100
  latents = noise * self.scheduler.init_noise_sigma
101
  return latents
102
 
 
199
  prompts = [prompts]
200
 
201
  self.set_scheduler_timesteps(n_steps)
202
+ latents = self.get_initial_latents(n_imgs, height, width, len(prompts), generator=generator)
203
  text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
204
  end_iteration = end_iteration or n_steps
205
  latents_steps, trace_steps = self.diffusion(
app.py CHANGED
@@ -10,22 +10,16 @@ from memory_efficiency import MemoryEfficiencyWrapper
10
  from train import train
11
 
12
  import os
13
- model_map = {'Van Gogh': 'models/vangogh.pt',
14
- 'Pablo Picasso': 'models/pablopicasso.pt',
15
- 'Car': 'models/car.pt',
16
- 'Garbage Truck': 'models/garbagetruck.pt',
17
- 'French Horn': 'models/frenchhorn.pt',
18
- 'Kilian Eng': 'models/kilianeng.pt',
19
- 'Thomas Kinkade': 'models/thomaskinkade.pt',
20
- 'Tyler Edlin': 'models/tyleredlin.pt',
21
- 'Kelly McKernan': 'models/kellymckernan.pt',
22
- 'Rembrandt': 'models/rembrandt.pt' }
23
- for model_file in os.listdir('models'):
24
- path = 'models/' + model_file
25
- if any([existing_path == path for existing_path in model_map.values()]):
26
- continue
27
- model_map[model_file] = path
28
 
 
 
 
 
 
 
 
 
 
29
 
30
  ORIGINAL_SPACE_ID = 'baulab/Erasing-Concepts-In-Diffusion'
31
  SPACE_ID = os.getenv('SPACE_ID')
@@ -85,6 +79,10 @@ class Demo:
85
  value='Van Gogh',
86
  interactive=True
87
  )
 
 
 
 
88
 
89
  self.seed_infr = gr.Number(
90
  label="Seed",
@@ -196,6 +194,11 @@ class Demo:
196
  label="Seed",
197
  info="Set to a fixed number for reproducible training results, or use -1 to pick randomly"
198
  )
 
 
 
 
 
199
 
200
  with gr.Column():
201
  self.train_memory_options = gr.Markdown(interactive=False,
@@ -215,6 +218,10 @@ class Demo:
215
  value="Train",
216
  )
217
 
 
 
 
 
218
  self.download = gr.Files()
219
 
220
  with gr.Tab("Export") as export_column:
@@ -268,7 +275,10 @@ class Demo:
268
  self.image_orig
269
  ]
270
  )
271
- self.train_button.click(self.train, inputs = [
 
 
 
272
  self.train_model_input,
273
  self.train_img_size_input,
274
  self.prompt_input,
@@ -281,9 +291,12 @@ class Demo:
281
  self.train_use_amp_input,
282
  self.train_use_gradient_checkpointing_input,
283
  self.train_seed_input,
 
284
  ],
285
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
286
  )
 
 
287
  self.export_button.click(self.export, inputs = [
288
  self.model_dropdown_export,
289
  self.base_repo_id_or_path_input_export,
@@ -293,9 +306,15 @@ class Demo:
293
  outputs=[self.export_status]
294
  )
295
 
 
 
 
 
 
 
296
  def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
297
  use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
298
- seed=-1,
299
  pbar = gr.Progress(track_tqdm=True)):
300
 
301
  if self.training:
@@ -331,10 +350,13 @@ class Demo:
331
 
332
  try:
333
  self.training = True
 
334
  train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
335
- use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing, seed=int(seed))
 
336
  finally:
337
  self.training = False
 
338
 
339
  torch.cuda.empty_cache()
340
 
 
10
  from train import train
11
 
12
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ def populate_model_map():
15
+ model_map = {}
16
+ for model_file in os.listdir('models'):
17
+ path = 'models/' + model_file
18
+ if any([existing_path == path for existing_path in model_map.values()]):
19
+ continue
20
+ model_map[model_file] = path
21
+ return model_map
22
+ model_map = populate_model_map()
23
 
24
  ORIGINAL_SPACE_ID = 'baulab/Erasing-Concepts-In-Diffusion'
25
  SPACE_ID = os.getenv('SPACE_ID')
 
79
  value='Van Gogh',
80
  interactive=True
81
  )
82
+ self.model_reload_button = gr.Button(
83
+ value="🔄",
84
+ interactive=True
85
+ )
86
 
87
  self.seed_infr = gr.Number(
88
  label="Seed",
 
194
  label="Seed",
195
  info="Set to a fixed number for reproducible training results, or use -1 to pick randomly"
196
  )
197
+ self.train_save_every_input = gr.Number(
198
+ value=-1,
199
+ label="Save every N steps",
200
+ info="If >0, save the model throughout training at the given step interval."
201
+ )
202
 
203
  with gr.Column():
204
  self.train_memory_options = gr.Markdown(interactive=False,
 
218
  value="Train",
219
  )
220
 
221
+ self.train_cancel_button = gr.Button(
222
+ value="Cancel training"
223
+ )
224
+
225
  self.download = gr.Files()
226
 
227
  with gr.Tab("Export") as export_column:
 
275
  self.image_orig
276
  ]
277
  )
278
+ self.model_reload_button.click(self.reload_models,
279
+ inputs=[self.model_dropdown],
280
+ outputs=[self.model_dropdown])
281
+ train_event = self.train_button.click(self.train, inputs = [
282
  self.train_model_input,
283
  self.train_img_size_input,
284
  self.prompt_input,
 
291
  self.train_use_amp_input,
292
  self.train_use_gradient_checkpointing_input,
293
  self.train_seed_input,
294
+ self.train_save_every_input,
295
  ],
296
  outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
297
  )
298
+ self.train_cancel_button.click(lambda x: print("cancel pressed"), cancels=[train_event])
299
+
300
  self.export_button.click(self.export, inputs = [
301
  self.model_dropdown_export,
302
  self.base_repo_id_or_path_input_export,
 
306
  outputs=[self.export_status]
307
  )
308
 
309
+ def reload_models(self, model_dropdown):
310
+ current_model_name = model_dropdown
311
+ global model_map
312
+ model_map = populate_model_map()
313
+ return [gr.Dropdown.update(choices=list(model_map.keys()), value=current_model_name)]
314
+
315
  def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
316
  use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
317
+ seed=-1, save_every=-1,
318
  pbar = gr.Progress(track_tqdm=True)):
319
 
320
  if self.training:
 
350
 
351
  try:
352
  self.training = True
353
+ self.train_cancel_button.update(interactive=True)
354
  train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
355
+ use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing,
356
+ seed=int(seed), save_every=int(save_every))
357
  finally:
358
  self.training = False
359
+ self.train_cancel_button.update(interactive=False)
360
 
361
  torch.cuda.empty_cache()
362
 
isolate_rng.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copy/pasted from pytorch lightning
2
+ # https://github.com/Lightning-AI/lightning/blob/0d52f4577310b5a1624bed4d23d49e37fb05af9e/src/lightning_fabric/utilities/seed.py
3
+ # and
4
+ # https://github.com/Lightning-AI/lightning/blob/98f7696d1681974d34fad59c03b4b58d9524ed13/src/pytorch_lightning/utilities/seed.py
5
+
6
+ # Copyright The Lightning team.
7
+ #
8
+ # Licensed under the Apache License, Version 2.0 (the "License");
9
+ # you may not use this file except in compliance with the License.
10
+ # You may obtain a copy of the License at
11
+ #
12
+ # http://www.apache.org/licenses/LICENSE-2.0
13
+ #
14
+ # Unless required by applicable law or agreed to in writing, software
15
+ # distributed under the License is distributed on an "AS IS" BASIS,
16
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17
+ # See the License for the specific language governing permissions and
18
+ # limitations under the License.
19
+
20
+ from contextlib import contextmanager
21
+ from typing import Generator, Dict, Any
22
+
23
+ import torch
24
+ import numpy as np
25
+ from random import getstate as python_get_rng_state
26
+ from random import setstate as python_set_rng_state
27
+
28
+
29
+ def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
30
+ """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
31
+ states = {
32
+ "torch": torch.get_rng_state(),
33
+ "numpy": np.random.get_state(),
34
+ "python": python_get_rng_state(),
35
+ }
36
+ if include_cuda:
37
+ states["torch.cuda"] = torch.cuda.get_rng_state_all()
38
+ return states
39
+
40
+
41
+ def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
42
+ """Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
43
+ process."""
44
+ torch.set_rng_state(rng_state_dict["torch"])
45
+ # torch.cuda rng_state is only included since v1.8.
46
+ if "torch.cuda" in rng_state_dict:
47
+ torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
48
+ np.random.set_state(rng_state_dict["numpy"])
49
+ version, state, gauss = rng_state_dict["python"]
50
+ python_set_rng_state((version, tuple(state), gauss))
51
+
52
+
53
+ @contextmanager
54
+ def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
55
+ """A context manager that resets the global random state on exit to what it was before entering.
56
+ It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
57
+ Args:
58
+ include_cuda: Whether to allow this function to also control the `torch.cuda` random number generator.
59
+ Set this to ``False`` when using the function in a forked process where CUDA re-initialization is
60
+ prohibited.
61
+ Example:
62
+ >>> import torch
63
+ >>> torch.manual_seed(1) # doctest: +ELLIPSIS
64
+ <torch._C.Generator object at ...>
65
+ >>> with isolate_rng():
66
+ ... [torch.rand(1) for _ in range(3)]
67
+ [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
68
+ >>> torch.rand(1)
69
+ tensor([0.7576])
70
+ """
71
+ states = _collect_rng_states(include_cuda)
72
+ yield
73
+ _set_rng_states(states)
train.py CHANGED
@@ -1,4 +1,4 @@
1
- from random import random
2
 
3
  from accelerate.utils import set_seed
4
  from torch.cuda.amp import autocast
@@ -8,11 +8,12 @@ from finetuning import FineTunedModel
8
  import torch
9
  from tqdm import tqdm
10
 
 
11
  from memory_efficiency import MemoryEfficiencyWrapper
12
 
13
 
14
  def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
15
- use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1):
16
 
17
  nsteps = 50
18
  diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
@@ -54,6 +55,9 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
54
  seed = random.randint(0, 2 ** 30)
55
  set_seed(int(seed))
56
 
 
 
 
57
  for i in pbar:
58
  with torch.no_grad():
59
  diffuser.set_scheduler_timesteps(nsteps)
@@ -92,6 +96,22 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
92
  memory_efficiency_wrapper.step(optimizer, loss)
93
  optimizer.zero_grad()
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  torch.save(finetuner.state_dict(), save_path)
96
 
97
  del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
 
1
+ import random
2
 
3
  from accelerate.utils import set_seed
4
  from torch.cuda.amp import autocast
 
8
  import torch
9
  from tqdm import tqdm
10
 
11
+ from isolate_rng import isolate_rng
12
  from memory_efficiency import MemoryEfficiencyWrapper
13
 
14
 
15
  def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
16
+ use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1, save_every=-1):
17
 
18
  nsteps = 50
19
  diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
 
55
  seed = random.randint(0, 2 ** 30)
56
  set_seed(int(seed))
57
 
58
+ prev_losses = []
59
+ start_loss = None
60
+ max_prev_loss_count = 10
61
  for i in pbar:
62
  with torch.no_grad():
63
  diffuser.set_scheduler_timesteps(nsteps)
 
96
  memory_efficiency_wrapper.step(optimizer, loss)
97
  optimizer.zero_grad()
98
 
99
+ # print moving average loss
100
+ prev_losses.append(loss.detach().clone())
101
+ if len(prev_losses) > max_prev_loss_count:
102
+ prev_losses.pop(0)
103
+ if start_loss is None:
104
+ start_loss = prev_losses[-1]
105
+ if len(prev_losses) >= max_prev_loss_count:
106
+ moving_average_loss = sum(prev_losses) / len(prev_losses)
107
+ print(
108
+ f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
109
+ else:
110
+ print(f"step {i}: loss={loss.item()}")
111
+
112
+ if save_every > 0 and ((i % save_every) == (save_every-1)):
113
+ torch.save(finetuner.state_dict(), save_path + f"__step_{i}.pt")
114
+
115
  torch.save(finetuner.state_dict(), save_path)
116
 
117
  del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents