Damian Stewart commited on
Commit
c8aa68b
1 Parent(s): a29f551

fix seed and export

Browse files
Files changed (3) hide show
  1. StableDiffuser.py +1 -1
  2. app.py +11 -8
  3. train.py +1 -1
StableDiffuser.py CHANGED
@@ -63,7 +63,7 @@ class StableDiffuser(torch.nn.Module):
63
  def get_noise(self, batch_size, width, height, generator=None):
64
  param = list(self.parameters())[0]
65
  return torch.randn(
66
- (batch_size, self.unet.in_channels, width // 8, height // 8),
67
  generator=generator).type(param.dtype).to(param.device)
68
 
69
  def add_noise(self, latents, noise, step):
 
63
  def get_noise(self, batch_size, width, height, generator=None):
64
  param = list(self.parameters())[0]
65
  return torch.randn(
66
+ (batch_size, self.unet.config.in_channels, width // 8, height // 8),
67
  generator=generator).type(param.dtype).to(param.device)
68
 
69
  def add_noise(self, latents, noise, step):
app.py CHANGED
@@ -199,12 +199,13 @@ class Demo:
199
 
200
  with gr.Column():
201
  self.train_memory_options = gr.Markdown(interactive=False,
202
- value='Performance and VRAM usage optimizations, may not work on all devices.')
203
  with gr.Row():
204
  self.train_use_adamw8bit_input = gr.Checkbox(label="8bit AdamW", value=True)
205
  self.train_use_xformers_input = gr.Checkbox(label="xformers", value=True)
206
  self.train_use_amp_input = gr.Checkbox(label="AMP", value=True)
207
- self.train_use_gradient_checkpointing_input = gr.Checkbox(label="Gradient checkpointing", value=True)
 
208
 
209
  with gr.Column(scale=1):
210
 
@@ -248,9 +249,10 @@ class Demo:
248
  )
249
 
250
  with gr.Column(scale=1):
 
 
251
  self.export_button = gr.Button(
252
- value="Export",
253
- )
254
 
255
  self.infr_button.click(self.inference, inputs = [
256
  self.prompt_input_infr,
@@ -288,12 +290,12 @@ class Demo:
288
  self.save_path_input_export,
289
  self.save_half_export
290
  ],
291
- outputs=[self.export_button]
292
  )
293
 
294
  def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
295
  use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
296
- seed = -1,
297
  pbar = gr.Progress(track_tqdm=True)):
298
 
299
  if self.training:
@@ -330,7 +332,7 @@ class Demo:
330
  try:
331
  self.training = True
332
  train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
333
- use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing, seed=seed)
334
  finally:
335
  self.training = False
336
 
@@ -355,8 +357,9 @@ class Demo:
355
  with finetuner:
356
  if save_half:
357
  diffuser = diffuser.half()
358
- diffuser.pipeline.to(torch.float16, torch_device=diffuser.device)
359
  diffuser.pipeline.save_pretrained(save_path)
 
360
 
361
 
362
  def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
 
199
 
200
  with gr.Column():
201
  self.train_memory_options = gr.Markdown(interactive=False,
202
+ value='Performance and VRAM usage optimizations, may not work on all devices:')
203
  with gr.Row():
204
  self.train_use_adamw8bit_input = gr.Checkbox(label="8bit AdamW", value=True)
205
  self.train_use_xformers_input = gr.Checkbox(label="xformers", value=True)
206
  self.train_use_amp_input = gr.Checkbox(label="AMP", value=True)
207
+ self.train_use_gradient_checkpointing_input = gr.Checkbox(
208
+ label="Gradient checkpointing", value=False)
209
 
210
  with gr.Column(scale=1):
211
 
 
249
  )
250
 
251
  with gr.Column(scale=1):
252
+ self.export_status = gr.Button(
253
+ value='', variant='primary', label='Status', interactive=False)
254
  self.export_button = gr.Button(
255
+ value="Export")
 
256
 
257
  self.infr_button.click(self.inference, inputs = [
258
  self.prompt_input_infr,
 
290
  self.save_path_input_export,
291
  self.save_half_export
292
  ],
293
+ outputs=[self.export_status]
294
  )
295
 
296
  def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
297
  use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
298
+ seed=-1,
299
  pbar = gr.Progress(track_tqdm=True)):
300
 
301
  if self.training:
 
332
  try:
333
  self.training = True
334
  train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
335
+ use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing, seed=int(seed))
336
  finally:
337
  self.training = False
338
 
 
357
  with finetuner:
358
  if save_half:
359
  diffuser = diffuser.half()
360
+ diffuser.pipeline.to('cpu', torch_dtype=torch.float16)
361
  diffuser.pipeline.save_pretrained(save_path)
362
+ return [gr.update(value=f'Done! Your model is at {save_path}.')]
363
 
364
 
365
  def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
train.py CHANGED
@@ -52,7 +52,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
52
 
53
  if seed == -1:
54
  seed = random.randint(0, 2 ** 30)
55
- set_seed(seed)
56
 
57
  for i in pbar:
58
  with torch.no_grad():
 
52
 
53
  if seed == -1:
54
  seed = random.randint(0, 2 ** 30)
55
+ set_seed(int(seed))
56
 
57
  for i in pbar:
58
  with torch.no_grad():