Spaces:
Runtime error
Runtime error
Damian Stewart
commited on
Commit
•
c8aa68b
1
Parent(s):
a29f551
fix seed and export
Browse files- StableDiffuser.py +1 -1
- app.py +11 -8
- train.py +1 -1
StableDiffuser.py
CHANGED
@@ -63,7 +63,7 @@ class StableDiffuser(torch.nn.Module):
|
|
63 |
def get_noise(self, batch_size, width, height, generator=None):
|
64 |
param = list(self.parameters())[0]
|
65 |
return torch.randn(
|
66 |
-
(batch_size, self.unet.in_channels, width // 8, height // 8),
|
67 |
generator=generator).type(param.dtype).to(param.device)
|
68 |
|
69 |
def add_noise(self, latents, noise, step):
|
|
|
63 |
def get_noise(self, batch_size, width, height, generator=None):
|
64 |
param = list(self.parameters())[0]
|
65 |
return torch.randn(
|
66 |
+
(batch_size, self.unet.config.in_channels, width // 8, height // 8),
|
67 |
generator=generator).type(param.dtype).to(param.device)
|
68 |
|
69 |
def add_noise(self, latents, noise, step):
|
app.py
CHANGED
@@ -199,12 +199,13 @@ class Demo:
|
|
199 |
|
200 |
with gr.Column():
|
201 |
self.train_memory_options = gr.Markdown(interactive=False,
|
202 |
-
value='Performance and VRAM usage optimizations, may not work on all devices
|
203 |
with gr.Row():
|
204 |
self.train_use_adamw8bit_input = gr.Checkbox(label="8bit AdamW", value=True)
|
205 |
self.train_use_xformers_input = gr.Checkbox(label="xformers", value=True)
|
206 |
self.train_use_amp_input = gr.Checkbox(label="AMP", value=True)
|
207 |
-
self.train_use_gradient_checkpointing_input = gr.Checkbox(
|
|
|
208 |
|
209 |
with gr.Column(scale=1):
|
210 |
|
@@ -248,9 +249,10 @@ class Demo:
|
|
248 |
)
|
249 |
|
250 |
with gr.Column(scale=1):
|
|
|
|
|
251 |
self.export_button = gr.Button(
|
252 |
-
value="Export"
|
253 |
-
)
|
254 |
|
255 |
self.infr_button.click(self.inference, inputs = [
|
256 |
self.prompt_input_infr,
|
@@ -288,12 +290,12 @@ class Demo:
|
|
288 |
self.save_path_input_export,
|
289 |
self.save_half_export
|
290 |
],
|
291 |
-
outputs=[self.
|
292 |
)
|
293 |
|
294 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
|
295 |
use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
|
296 |
-
seed
|
297 |
pbar = gr.Progress(track_tqdm=True)):
|
298 |
|
299 |
if self.training:
|
@@ -330,7 +332,7 @@ class Demo:
|
|
330 |
try:
|
331 |
self.training = True
|
332 |
train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
|
333 |
-
use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing, seed=seed)
|
334 |
finally:
|
335 |
self.training = False
|
336 |
|
@@ -355,8 +357,9 @@ class Demo:
|
|
355 |
with finetuner:
|
356 |
if save_half:
|
357 |
diffuser = diffuser.half()
|
358 |
-
diffuser.pipeline.to(
|
359 |
diffuser.pipeline.save_pretrained(save_path)
|
|
|
360 |
|
361 |
|
362 |
def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
|
|
|
199 |
|
200 |
with gr.Column():
|
201 |
self.train_memory_options = gr.Markdown(interactive=False,
|
202 |
+
value='Performance and VRAM usage optimizations, may not work on all devices:')
|
203 |
with gr.Row():
|
204 |
self.train_use_adamw8bit_input = gr.Checkbox(label="8bit AdamW", value=True)
|
205 |
self.train_use_xformers_input = gr.Checkbox(label="xformers", value=True)
|
206 |
self.train_use_amp_input = gr.Checkbox(label="AMP", value=True)
|
207 |
+
self.train_use_gradient_checkpointing_input = gr.Checkbox(
|
208 |
+
label="Gradient checkpointing", value=False)
|
209 |
|
210 |
with gr.Column(scale=1):
|
211 |
|
|
|
249 |
)
|
250 |
|
251 |
with gr.Column(scale=1):
|
252 |
+
self.export_status = gr.Button(
|
253 |
+
value='', variant='primary', label='Status', interactive=False)
|
254 |
self.export_button = gr.Button(
|
255 |
+
value="Export")
|
|
|
256 |
|
257 |
self.infr_button.click(self.inference, inputs = [
|
258 |
self.prompt_input_infr,
|
|
|
290 |
self.save_path_input_export,
|
291 |
self.save_half_export
|
292 |
],
|
293 |
+
outputs=[self.export_status]
|
294 |
)
|
295 |
|
296 |
def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
|
297 |
use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
|
298 |
+
seed=-1,
|
299 |
pbar = gr.Progress(track_tqdm=True)):
|
300 |
|
301 |
if self.training:
|
|
|
332 |
try:
|
333 |
self.training = True
|
334 |
train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
|
335 |
+
use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing, seed=int(seed))
|
336 |
finally:
|
337 |
self.training = False
|
338 |
|
|
|
357 |
with finetuner:
|
358 |
if save_half:
|
359 |
diffuser = diffuser.half()
|
360 |
+
diffuser.pipeline.to('cpu', torch_dtype=torch.float16)
|
361 |
diffuser.pipeline.save_pretrained(save_path)
|
362 |
+
return [gr.update(value=f'Done! Your model is at {save_path}.')]
|
363 |
|
364 |
|
365 |
def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
|
train.py
CHANGED
@@ -52,7 +52,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
|
|
52 |
|
53 |
if seed == -1:
|
54 |
seed = random.randint(0, 2 ** 30)
|
55 |
-
set_seed(seed)
|
56 |
|
57 |
for i in pbar:
|
58 |
with torch.no_grad():
|
|
|
52 |
|
53 |
if seed == -1:
|
54 |
seed = random.randint(0, 2 ** 30)
|
55 |
+
set_seed(int(seed))
|
56 |
|
57 |
for i in pbar:
|
58 |
with torch.no_grad():
|