Damian Stewart committed
Commit ab11bdd
1 Parent(s): fc73e59

actually use AMP = 3x speedup
Files changed (3)
  1. StableDiffuser.py +4 -8
  2. app.py +23 -12
  3. train.py +13 -12
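
Context for the commit title: the previous revision imported GradScaler in StableDiffuser.py but, going by the title, the forward passes never actually ran under torch.cuda.amp.autocast; this commit adds the autocast contexts in train.py. As a reminder of the standard AMP pattern, here is a minimal sketch with a placeholder model and data, not code from this repo:

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(64, 64).cuda()        # placeholder module, not the UNet
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler()                         # scales the loss so fp16 gradients don't underflow

for _ in range(10):
    x = torch.randn(8, 64, device='cuda')
    optimizer.zero_grad()
    with autocast():                          # forward pass runs in mixed precision
        loss = model(x).square().mean()
    scaler.scale(loss).backward()             # backward on the scaled loss
    scaler.step(optimizer)                    # unscales gradients, then steps the optimizer
    scaler.update()                           # adapts the scale factor for the next iteration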
StableDiffuser.py CHANGED
@@ -4,7 +4,6 @@ import torch
 from baukit import TraceDict
 from diffusers import StableDiffusionPipeline
 from PIL import Image
-from torch.cuda.amp import GradScaler
 from tqdm.auto import tqdm
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
@@ -35,6 +34,7 @@ class StableDiffuser(torch.nn.Module):
 
     def __init__(self,
                  scheduler='LMS',
+                 keep_pipeline=False,
                  repo_id_or_path="CompVis/stable-diffusion-v1-4"):
 
         super().__init__()
@@ -46,6 +46,7 @@ class StableDiffuser(torch.nn.Module):
         self.tokenizer = self.pipeline.tokenizer
         self.text_encoder = self.pipeline.text_encoder
         self.safety_checker = self.pipeline.safety_checker
+        self.feature_extractor = self.pipeline.feature_extractor
 
         if scheduler == 'LMS':
             self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
@@ -55,10 +56,8 @@ class StableDiffuser(torch.nn.Module):
             self.scheduler = DDPMScheduler.from_pretrained(repo_id_or_path, subfolder="scheduler")
 
         self.eval()
-
-    @property
-    def feature_extractor(self):
-        return self.pipeline.feature_extractor
+        if not keep_pipeline:
+            del self.pipeline
 
     def get_noise(self, batch_size, width, height, generator=None):
         param = list(self.parameters())[0]
@@ -226,9 +225,6 @@ class StableDiffuser(torch.nn.Module):
 
         return images_steps
 
-    def save_pretrained(self, path, **kwargs):
-        self.pipeline.save_pretrained(path, **kwargs)
-
 
 if __name__ == '__main__':
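
The new keep_pipeline flag trades convenience for memory: by default the wrapping StableDiffusionPipeline is deleted once its submodules (tokenizer, text_encoder, safety_checker, ...) have been pulled out, and callers that still need pipeline methods such as save_pretrained must opt in. A hedged usage sketch, with illustrative paths:

from StableDiffuser import StableDiffuser

# Training path: the wrapping pipeline object is dropped after construction,
# so only the extracted submodules remain on the module.
diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
# diffuser.pipeline  ->  AttributeError: it was deleted in __init__

# Export path: keep the pipeline so pipeline.save_pretrained() stays available.
exporter = StableDiffuser(scheduler='DDIM', keep_pipeline=True).eval()
exporter.pipeline.save_pretrained('exported_model')   # illustrative path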
app.py CHANGED
@@ -162,9 +162,9 @@ class Demo:
             info="Prompt corresponding to concept to erase"
         )
 
-        choices = ['ESD-x', 'ESD-self']
-        if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40 or is_xformers_available():
-            choices.append('ESD-u')
+        choices = ['ESD-x', 'ESD-self', 'ESD-u']
+        #if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40 or is_xformers_available():
+        #    choices.append('ESD-u')
 
         self.train_method_input = gr.Dropdown(
             choices=choices,
@@ -274,7 +274,7 @@
                 self.train_use_amp_input,
                 #self.train_use_gradient_checkpointing_input
             ],
-            outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
+            outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
         )
         self.export_button.click(self.export, inputs = [
             self.model_dropdown_export,
@@ -286,12 +286,19 @@
         )
 
     def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
-              use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=True,
+              use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
               pbar = gr.Progress(track_tqdm=True)):
 
         if self.training:
             return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
 
+        print(f"Training {repo_id_or_path} at {img_size} to remove '{prompt}'.")
+        print(f"  {train_method}, negative guidance {neg_guidance}, lr {lr}, {iterations} iterations.")
+        print(f"  {'✅' if use_gradient_checkpointing else '❌'} gradient checkpointing")
+        print(f"  {'✅' if use_amp else '❌'} AMP")
+        print(f"  {'✅' if use_xformers else '❌'} xformers")
+        print(f"  {'✅' if use_adamw8bit else '❌'} 8-bit AdamW")
+
         if train_method == 'ESD-x':
             modules = ".*attn2$"
             frozen = []
@@ -319,20 +326,24 @@
         new_model_name = f'*new* {os.path.basename(save_path)}'
         model_map[new_model_name] = save_path
 
-        return [gr.update(interactive=True, value='Train'), gr.update(value=f'Done Training! \n '
-                'Try your model ({new_model_name}) in the "Test" tab'), save_path,
+        return [gr.update(interactive=True, value='Train'),
+                gr.update(value=f'Done Training! Try your model ({new_model_name}) in the "Test" tab'),
+                save_path,
                 gr.Dropdown.update(choices=list(model_map.keys()), value=new_model_name)]
 
     def export(self, model_name, base_repo_id_or_path, save_path, save_half):
         model_path = model_map[model_name]
         checkpoint = torch.load(model_path)
-        self.diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=base_repo_id_or_path).to('cuda').eval()
-        finetuner = FineTunedModel.from_checkpoint(self.diffuser, checkpoint).eval()
+        diffuser = StableDiffuser(scheduler='DDIM',
+                                  keep_pipeline=True,
+                                  repo_id_or_path=base_repo_id_or_path
+                                  ).eval()
+        finetuner = FineTunedModel.from_checkpoint(diffuser, checkpoint).eval()
         with finetuner:
             if save_half:
-                self.diffuser = self.diffuser.half()
-            self.diffuser.pipeline.to(torch.float16, torch_device=self.diffuser.device)
-            self.diffuser.save_pretrained(save_path)
+                diffuser = diffuser.half()
+                diffuser.pipeline.to(torch.float16, torch_device=diffuser.device)
+            diffuser.pipeline.save_pretrained(save_path)
 
 
     def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
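
Because export() now saves through pipeline.save_pretrained, the output folder is a standard diffusers layout and should load back with the stock API. A sketch, where the path is hypothetical and float16 matches a model exported with save_half=True:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    'exported_model',             # hypothetical save_path passed to export()
    torch_dtype=torch.float16,    # matches save_half=True
).to('cuda')

image = pipe('a painting of a landscape').images[0]
image.save('sample.png')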
train.py CHANGED
@@ -1,3 +1,5 @@
+from torch.cuda.amp import autocast
+
 from StableDiffuser import StableDiffuser
 from finetuning import FineTunedModel
 import torch
@@ -8,20 +10,17 @@ from memory_efficiency import MemoryEfficiencyWrapper
 
 def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
           use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False):
-
-    nsteps = 50
 
+    nsteps = 50
     diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
 
     memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                         use_gradient_checkpointing=use_gradient_checkpointing )
     with memory_efficiency_wrapper:
-
         diffuser.train()
-
         finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
-
         if use_adamw8bit:
+            print("using AdamW 8Bit optimizer")
             import bitsandbytes as bnb
             optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                             lr=lr,
@@ -30,13 +29,13 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
                                             eps=1e-8
                                             )
         else:
+            print("using Adam optimizer")
             optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
         criteria = torch.nn.MSELoss()
 
         pbar = tqdm(range(iterations))
 
         with torch.no_grad():
-
             neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
             positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
 
@@ -56,7 +55,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
             iteration = torch.randint(1, nsteps - 1, (1,)).item()
             latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
 
-            with finetuner:
+            with autocast(enabled=use_amp), finetuner:
                 latents_steps, _ = diffuser.diffusion(
                     latents,
                     positive_text_embeddings,
@@ -67,19 +66,21 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
                 )
 
             diffuser.set_scheduler_timesteps(1000)
-
             iteration = int(iteration / nsteps * 1000)
 
-            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
-            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
+            with autocast(enabled=use_amp):
+                positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
+                neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
 
             with finetuner:
-                negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
+                with autocast(enabled=use_amp):
+                    negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
 
             positive_latents.requires_grad = False
             neutral_latents.requires_grad = False
 
-            loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
+            # loss = criteria(e_n, e_0) works the best try 5000 epochs
+            loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
             memory_efficiency_wrapper.step(optimizer, loss)
             optimizer.step()
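
memory_efficiency.py is not part of this diff, so the contract of memory_efficiency_wrapper.step(optimizer, loss) is not visible here. With use_amp=True it presumably routes the backward pass through a GradScaler, along these lines (a hypothetical sketch under that assumption, not the repo's actual implementation):

from torch.cuda.amp import GradScaler

class MemoryEfficiencyWrapperSketch:
    """Hypothetical stand-in for MemoryEfficiencyWrapper's AMP handling."""

    def __init__(self, use_amp: bool):
        self.use_amp = use_amp
        self.scaler = GradScaler(enabled=use_amp)

    def step(self, optimizer, loss):
        if self.use_amp:
            self.scaler.scale(loss).backward()   # backward on the scaled loss
            self.scaler.step(optimizer)          # unscales gradients, then steps
            self.scaler.update()                 # adapts the scale factor
        else:
            loss.backward()

Note that the training loop above calls optimizer.step() unconditionally right after wrapper.step(...); if the wrapper steps the optimizer itself, as GradScaler.step does, that second call would be redundant. The actual division of labor depends on memory_efficiency.py.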