Damian Stewart committed
Commit bf1e262
1 Parent(s): 52c8f3c

allow multiple train prompts

Files changed (3):
  1. app.py +49 -37
  2. memory_efficiency.py +4 -1
  3. train.py +75 -61
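In short, the training entry points now take a list of erase prompts instead of a single prompt string, collected in the UI as one newline-separated textbox. A minimal sketch of calling the updated train() with several prompts; the modules/freeze_modules values and hyperparameters below are illustrative stand-ins for what app.py derives from the chosen ESD method, not the app's actual defaults:

    # Hedged sketch of the new multi-prompt entry point.
    from train import train

    save_path = train(
        "CompVis/stable-diffusion-v1-4",   # repo_id_or_path
        512,                               # img_size
        ["van gogh", "picasso"],           # prompts: one concept to erase per entry
        modules=".*attn2.*",               # illustrative ESD-x-style module pattern
        freeze_modules=[],                 # illustrative
        iterations=100,
        negative_guidance=1.0,
        lr=1e-5,
        save_path="models/example.pt",
    )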
app.py CHANGED
@@ -12,15 +12,20 @@ from train import train, training_should_cancel
 import os
 
 model_map = {}
-def populate_model_map():
+model_names_list = []
+
+def populate_global_model_map():
     global model_map
+    global model_names_list
     for model_file in os.listdir('models'):
         path = 'models/' + model_file
         if any([existing_path == path for existing_path in model_map.values()]):
             continue
         model_map[model_file] = path
-    return model_map
-model_map = populate_model_map()
+    model_names_list.clear()
+    model_names_list.extend(model_map.keys())
+
+populate_global_model_map()
 
 ORIGINAL_SPACE_ID = 'baulab/Erasing-Concepts-In-Diffusion'
 SPACE_ID = os.getenv('SPACE_ID')
@@ -64,6 +69,12 @@ class Demo:
 
             with gr.Column(scale=1):
 
+                self.base_repo_id_or_path_input_infr = gr.Text(
+                    label="Base model",
+                    value="CompVis/stable-diffusion-v1-4",
+                    info="Path or huggingface repo id of the base model that this edit was done against"
+                )
+
                 self.prompt_input_infr = gr.Text(
                     placeholder="Enter prompt...",
                     label="Prompt",
@@ -104,12 +115,6 @@ class Demo:
                     interactive=True
                 )
 
-                self.base_repo_id_or_path_input_infr = gr.Text(
-                    label="Base model",
-                    value="CompVis/stable-diffusion-v1-4",
-                    info="Path or huggingface repo id of the base model that this edit was done against"
-                )
-
             with gr.Column(scale=2):
 
                 self.infr_button = gr.Button(
@@ -152,19 +157,10 @@ class Demo:
                     info="Image size for training, should match the model's native image size"
                 )
 
-                self.train_sample_batch_size_input = gr.Slider(
-                    value=1,
-                    step=1,
-                    minimum=1,
-                    maximum=32,
-                    label="Sample generation batch size",
-                    info="Batch size for sample generation, larger needs more VRAM"
-                )
-
-                self.prompt_input = gr.Text(
-                    placeholder="Enter prompt...",
-                    label="Prompt to Erase",
-                    info="Prompt corresponding to concept to erase"
+                self.train_prompts_input = gr.Text(
+                    placeholder="Enter prompts, one per line",
+                    label="Prompts to Erase",
+                    info="Prompts corresponding to concepts to erase, one per line"
                 )
 
                 choices = ['ESD-x', 'ESD-self', 'ESD-u']
@@ -175,7 +171,7 @@ class Demo:
                     choices=choices,
                     value='ESD-x',
                     label='Train Method',
-                    info='Method of training'
+                    info='Method of training. ESD-x uses the least VRAM, and you may get OOM errors with the other methods.'
                 )
 
                 self.neg_guidance_input = gr.Number(
@@ -233,11 +229,21 @@ class Demo:
                     value='',
                     info="Negative prompts for use when generating sample images. One for each positive prompt, or leave empty for none."
                 )
-                self.train_validate_every_n_steps = gr.Number(
-                    label="Validate Every N Steps",
-                    value=20,
-                    info="Validation and sample generation will be run at intervals of this many steps"
-                )
+
+                with gr.Row():
+                    self.train_sample_batch_size_input = gr.Slider(
+                        value=1,
+                        step=1,
+                        minimum=1,
+                        maximum=32,
+                        label="Sample generation batch size",
+                        info="Batch size for sample generation, larger needs more VRAM"
+                    )
+                    self.train_validate_every_n_steps = gr.Number(
+                        label="Validate Every N Steps",
+                        value=20,
+                        info="Validation and sample generation will be run at intervals of this many steps"
+                    )
 
             with gr.Column(scale=1):
 
@@ -311,7 +317,7 @@ class Demo:
         train_event = self.train_button.click(self.train, inputs = [
             self.train_model_input,
             self.train_img_size_input,
-            self.prompt_input,
+            self.train_prompts_input,
             self.train_method_input,
             self.neg_guidance_input,
             self.iterations_input,
@@ -346,9 +352,9 @@ class Demo:
 
     def reload_models(self, model_dropdown):
        current_model_name = model_dropdown
-        global model_map
-        populate_model_map()
-        return [self.model_dropdown.update(choices=list(model_map.keys()), value=current_model_name)]
+        populate_global_model_map()
+        global model_names_list
+        return [self.model_dropdown.update(choices=model_names_list, value=current_model_name)]
 
     def cancel_training(self):
         if self.training:
@@ -356,7 +362,7 @@ class Demo:
             print("cancellation requested...")
         return [gr.update(value="Cancelling...", interactive=True)]
 
-    def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
+    def train(self, repo_id_or_path, img_size, prompts, train_method, neg_guidance, iterations, lr,
               use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
               seed=-1, save_every=-1, sample_batch_size=1,
               validation_prompts: str=None, sample_positive_prompts: str=None, sample_negative_prompts: str=None, validate_every_n_steps=-1,
@@ -365,7 +371,7 @@ class Demo:
 
         :param repo_id_or_path:
         :param img_size:
-        :param prompt:
+        :param prompts:
         :param train_method:
         :param neg_guidance:
         :param iterations:
@@ -386,7 +392,7 @@ class Demo:
         if self.training:
             return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
 
-        print(f"Training {repo_id_or_path} at {img_size} to remove '{prompt}'.")
+        print(f"Training {repo_id_or_path} at {img_size} to remove '{prompts}'.")
         print(f"  {train_method}, negative guidance {neg_guidance}, lr {lr}, {iterations} iterations.")
         print(f"  {'✅' if use_gradient_checkpointing else '❌'} gradient checkpointing")
         print(f"  {'✅' if use_amp else '❌'} AMP")
@@ -409,11 +415,12 @@ class Demo:
         while True:
             randn = torch.randint(1, 10000000, (1,)).item()
             options = f'{"a8" if use_adamw8bit else ""}{"AM" if use_amp else ""}{"xf" if use_xformers else ""}{"gc" if use_gradient_checkpointing else ""}'
-            save_path = f"models/{prompt.lower().replace(' ', '')}_{train_method}_ng{neg_guidance}_lr{lr}_iter{iterations}_seed{seed}_{options}__{randn}.pt"
+            save_path = f"models/{prompts[0].lower().replace(' ', '')}_{train_method}_ng{neg_guidance}_lr{lr}_iter{iterations}_seed{seed}_{options}__{randn}.pt"
             if not os.path.exists(save_path):
                 break
             # repeat until a not-in-use path is found
 
+        prompts = [p for p in prompts.split('\n') if len(p)>0]
         validation_prompts = [] if validation_prompts is None else [p for p in validation_prompts.split('\n') if len(p)>0]
         sample_positive_prompts = [] if sample_positive_prompts is None else [p for p in sample_positive_prompts.split('\n') if len(p)>0]
         sample_negative_prompts = [] if sample_negative_prompts is None else sample_negative_prompts.split('\n')
@@ -425,7 +432,7 @@ class Demo:
         self.training = True
         self.train_cancel_button.update(interactive=True)
         batch_size = 1 # other batch sizes are non-functional
-        save_path = train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
+        save_path = train(repo_id_or_path, img_size, prompts, modules, frozen, iterations, neg_guidance, lr, save_path,
                           use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing,
                           seed=int(seed), save_every_n_steps=int(save_every),
                           batch_size=int(batch_size), sample_batch_size=int(sample_batch_size),
@@ -476,6 +483,11 @@ class Demo:
         model_path = model_map[model_name]
         checkpoint = torch.load(model_path)
 
+        if type(prompt) is str:
+            prompt = [prompt]
+        if type(negative_prompt) is str:
+            negative_prompt = [negative_prompt]
+
         self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval().half()
         finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval().half()

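For reference, Demo.train receives the textbox contents as a single newline-separated string and splits it into a prompt list. A minimal sketch of that convention; parse_prompts is a hypothetical helper, not part of the commit:

    # Hypothetical helper mirroring the split in Demo.train: one prompt per
    # non-empty line of the "Prompts to Erase" textbox.
    def parse_prompts(raw: str) -> list:
        return [p for p in raw.split('\n') if len(p) > 0]

    assert parse_prompts("van gogh\npicasso\n") == ["van gogh", "picasso"]

    # Caveat: in this commit the save filename is built from prompts[0] before
    # the split runs, so at that point prompts[0] is the first character of the
    # raw string rather than the first prompt; splitting first (as here) would
    # avoid that.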
memory_efficiency.py CHANGED
@@ -66,10 +66,13 @@ class MemoryEfficiencyWrapper:
                 growth_interval=25,
             )
 
-    def step(self, optimizer, loss):
+    def backward(self, loss):
         self.grad_scaler.scale(loss).backward()
+
+    def step(self, optimizer):
         self.grad_scaler.step(optimizer)
         self.grad_scaler.update()
+        optimizer.zero_grad(set_to_none=True)
 
     def __exit__(self, exc_type, exc_value, tb):
         if exc_type is not None:
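Splitting the old step(optimizer, loss) into backward(loss) and step(optimizer) lets the training loop backpropagate one loss per prompt and apply a single optimizer update afterwards, with step() now also clearing gradients. A minimal sketch of the same accumulate-then-step pattern, using torch.cuda.amp.GradScaler directly as a stand-in for the wrapper's scaler:

    import torch

    scaler = torch.cuda.amp.GradScaler()

    def backward(loss):
        scaler.scale(loss).backward()          # gradients accumulate across calls

    def step(optimizer):
        scaler.step(optimizer)                 # one update covering all accumulated grads
        scaler.update()
        optimizer.zero_grad(set_to_none=True)  # clear for the next iteration

Several backward() calls followed by one step() is the standard gradient-accumulation pattern; scaling each loss before its backward pass keeps AMP's loss scaling consistent across the accumulated prompts.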
train.py CHANGED
@@ -1,6 +1,7 @@
 import os.path
 import random
 import multiprocessing
+import math
 
 from accelerate.utils import set_seed
 from diffusers import StableDiffusionPipeline
@@ -34,11 +35,13 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
     set_seed(validation_seed)
     criteria = torch.nn.MSELoss()
     negative_guidance = 1
-    val_count = 5
 
     nsteps=50
     num_validation_batches = validation_embeddings.shape[0] // (batch_size*2)
 
+    val_count = max(1, 5 // num_validation_batches)
+
+    val_total_loss = 0
     for i in tqdm(range(num_validation_batches)):
         if training_should_cancel.acquire(block=False):
             print("cancel requested, bailing")
@@ -58,9 +61,11 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
 
             loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
             accumulated_loss = (accumulated_loss or 0) + loss.item()
+            val_total_loss += loss.item()
         logger.add_scalar(f"loss/val_{i}", accumulated_loss/val_count, global_step=global_step)
+    logger.add_scalar(f"loss/_val_all_combined", val_total_loss/(val_count*num_validation_batches), global_step=global_step)
 
-    num_sample_batches = sample_embeddings.shape[0] // (sample_batch_size*2)
+    num_sample_batches = int(math.ceil(sample_embeddings.shape[0] / (sample_batch_size*2)))
     for i in tqdm(range(0, num_sample_batches)):
         print(f'making sample batch {i}...')
         if training_should_cancel.acquire(block=False):
@@ -82,9 +87,9 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
             images = pipeline(prompt_embeds=batch_prompt_embeds, #sample_embeddings[i*2+1:i*2+2],
                               negative_prompt_embeds=batch_negative_prompt_embeds, # sample_embeddings[i*2:i*2+1],
                               num_inference_steps=50)
-            for j in range(sample_batch_size):
-                image_tensor = transforms.ToTensor()(images.images[j])
-                logger.add_image(f"samples/{i*sample_batch_size+j}", img_tensor=image_tensor, global_step=global_step)
+            for image_index, image in enumerate(images.images):
+                image_tensor = transforms.ToTensor()(image)
+                logger.add_image(f"samples/{i*sample_batch_size+image_index}", img_tensor=image_tensor, global_step=global_step)
 
     """
     with finetuner, torch.cuda.amp.autocast(enabled=use_amp):
@@ -97,20 +102,12 @@ def validate(diffuser: StableDiffuser, finetuner: FineTunedModel,
 
     torch.cuda.empty_cache()
 
-def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
+def train(repo_id_or_path, img_size, prompts, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
           use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1,
          batch_size=1, sample_batch_size=1,
          save_every_n_steps=-1, validate_every_n_steps=-1,
         validation_prompts=[], sample_positive_prompts=[], sample_negative_prompts=[]):
 
-    diffuser = None
-    loss = None
-    optimizer = None
-    finetuner = None
-    negative_latents = None
-    neutral_latents = None
-    positive_latents = None
-
     nsteps = 50
     print(f"using img_size of {img_size}")
     diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path, native_img_size=img_size).to('cuda')
@@ -118,7 +115,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
 
     memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                         use_gradient_checkpointing=use_gradient_checkpointing )
-    with memory_efficiency_wrapper:
+    with (((((memory_efficiency_wrapper))))):
         diffuser.train()
         finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
         if use_adamw8bit:
@@ -139,7 +136,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
 
         with torch.no_grad():
             neutral_text_embeddings = diffuser.get_cond_and_uncond_embeddings([''], n_imgs=1)
-            positive_text_embeddings = diffuser.get_cond_and_uncond_embeddings([prompt], n_imgs=1)
+            all_positive_text_embeddings = diffuser.get_cond_and_uncond_embeddings(prompts, n_imgs=1)
             validation_embeddings = diffuser.get_cond_and_uncond_embeddings(validation_prompts, n_imgs=1)
             sample_embeddings = diffuser.get_cond_and_uncond_embeddings(sample_positive_prompts, sample_negative_prompts, n_imgs=1)
 
@@ -173,51 +170,68 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
         start_loss = None
         max_prev_loss_count = 10
         try:
-            for i in pbar:
-                if training_should_cancel.acquire(block=False):
-                    print("cancel requested, bailing")
-                    return None
-
-                with torch.no_grad():
-                    optimizer.zero_grad()
-
-                    iteration = torch.randint(1, nsteps - 1, (1,)).item()
-
-                    with finetuner:
-                        diffused_latents = get_diffused_latents(diffuser, nsteps, positive_text_embeddings, iteration, use_amp)
-
-                    iteration = int(iteration / nsteps * 1000)
-
-                    with autocast(enabled=use_amp):
-                        positive_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
-                        neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_text_embeddings, guidance_scale=1)
-
-                with finetuner:
-                    with autocast(enabled=use_amp):
-                        negative_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
-
-                positive_latents.requires_grad = False
-                neutral_latents.requires_grad = False
-
-                # loss = criteria(e_n, e_0) works the best try 5000 epochs
-                loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
-                memory_efficiency_wrapper.step(optimizer, loss)
-                optimizer.zero_grad()
-
-                logger.add_scalar("loss", loss.item(), global_step=i)
-
-                # print moving average loss
-                prev_losses.append(loss.detach().clone())
-                if len(prev_losses) > max_prev_loss_count:
-                    prev_losses.pop(0)
-                if start_loss is None:
-                    start_loss = prev_losses[-1]
-                if len(prev_losses) >= max_prev_loss_count:
-                    moving_average_loss = sum(prev_losses) / len(prev_losses)
-                    print(
-                        f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
-                else:
-                    print(f"step {i}: loss={loss.item()}")
+            loss=None
+            negative_latents=None
+            neutral_latents=None
+            positive_latents=None
+
+            num_prompts = all_positive_text_embeddings.shape[0] // 2
+            for i in pbar:
+                try:
+                    loss = None
+                    negative_latents = None
+                    positive_latents = None
+                    neutral_latents = None
+                    diffused_latents = None
+                    for j in tqdm(range(num_prompts)):
+                        positive_text_embeddings = all_positive_text_embeddings[j*2:j*2+2]
+                        if training_should_cancel.acquire(block=False):
+                            print("cancel requested, bailing")
+                            return None
+
+                        with torch.no_grad():
+                            optimizer.zero_grad()
+
+                            iteration = torch.randint(1, nsteps - 1, (1,)).item()
+
+                            with finetuner:
+                                diffused_latents = get_diffused_latents(diffuser, nsteps, positive_text_embeddings, iteration, use_amp)
+
+                            iteration = int(iteration / nsteps * 1000)
+
+                            with autocast(enabled=use_amp):
+                                positive_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
+                                neutral_latents = diffuser.predict_noise(iteration, diffused_latents, neutral_text_embeddings, guidance_scale=1)
+
+                        with finetuner:
+                            with autocast(enabled=use_amp):
+                                negative_latents = diffuser.predict_noise(iteration, diffused_latents, positive_text_embeddings, guidance_scale=1)
+
+                        positive_latents.requires_grad = False
+                        neutral_latents.requires_grad = False
+
+                        # loss = criteria(e_n, e_0) works the best try 5000 epochs
+                        loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
+                        memory_efficiency_wrapper.backward(loss)
+
+                        logger.add_scalar("loss", loss.item(), global_step=i)
+
+                        # print moving average loss
+                        prev_losses.append(loss.detach().clone())
+                        if len(prev_losses) > max_prev_loss_count:
+                            prev_losses.pop(0)
+                        if start_loss is None:
+                            start_loss = prev_losses[-1]
+                        if len(prev_losses) >= max_prev_loss_count:
+                            moving_average_loss = sum(prev_losses) / len(prev_losses)
+                            print(
+                                f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
+                        else:
+                            print(f"step {i}: loss={loss.item()}")
+
+                    memory_efficiency_wrapper.step(optimizer)
+                finally:
+                    del loss, negative_latents, positive_latents, neutral_latents, diffused_latents
 
                 if save_every_n_steps > 0 and ((i+1) % save_every_n_steps) == 0:
                     torch.save(finetuner.state_dict(), save_path + f"__step_{i+1}.pt")
@@ -231,7 +245,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
             torch.save(finetuner.state_dict(), save_path)
             return save_path
     finally:
-        del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents
+        del diffuser, optimizer, finetuner
         torch.cuda.empty_cache()
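The rewritten loop slices the stacked cond/uncond embedding pairs (hence shape[0] // 2 and the [j*2:j*2+2] slice), runs one backward pass per erase prompt, and then takes a single optimizer step. Note that the committed code still calls optimizer.zero_grad() inside the per-prompt loop, which would discard the gradients of all but the last prompt; the runnable toy sketch below (dummy loss, no diffusion model) zeroes once per outer step instead, which is the accumulation pattern the split backward()/step() API suggests:

    import torch

    # Toy stand-ins: 2 prompts -> 4 rows of cond/uncond embedding pairs.
    all_positive_text_embeddings = torch.randn(4, 77, 768)
    param = torch.nn.Parameter(torch.zeros(()))
    optimizer = torch.optim.SGD([param], lr=1e-2)

    def esd_loss_for(pair):
        # Hypothetical loss; the real one compares predicted noise latents.
        return (pair.mean() - param) ** 2

    num_prompts = all_positive_text_embeddings.shape[0] // 2   # cond/uncond pairs
    for i in range(3):                                         # stands in for `pbar`
        optimizer.zero_grad()                                  # once per step
        for j in range(num_prompts):
            pair = all_positive_text_embeddings[j*2:j*2+2]     # one prompt's pair
            loss = esd_loss_for(pair)
            loss.backward()                                    # accumulate across prompts
        optimizer.step()                                       # single update per iteration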