Commit 6067469 by Damian Stewart
Parent(s): c8aa68b

save every N steps and loss logging

Files changed:
- StableDiffuser.py +3 -3
- app.py +40 -18
- isolate_rng.py +73 -0
- train.py +22 -2
StableDiffuser.py
CHANGED
@@ -95,8 +95,8 @@ class StableDiffuser(torch.nn.Module):
     def set_scheduler_timesteps(self, n_steps):
         self.scheduler.set_timesteps(n_steps, device=self.unet.device)
 
-    def get_initial_latents(self, n_imgs,
-        noise = self.get_noise(n_imgs,
+    def get_initial_latents(self, n_imgs, height, width, n_prompts, generator=None):
+        noise = self.get_noise(n_imgs, height, width, generator=generator).repeat(n_prompts, 1, 1, 1)
         latents = noise * self.scheduler.init_noise_sigma
         return latents
 
@@ -199,7 +199,7 @@ class StableDiffuser(torch.nn.Module):
             prompts = [prompts]
 
         self.set_scheduler_timesteps(n_steps)
-        latents = self.get_initial_latents(n_imgs,
+        latents = self.get_initial_latents(n_imgs, height, width, len(prompts), generator=generator)
         text_embeddings = self.get_text_embeddings(prompts,negative_prompts,n_imgs=n_imgs)
         end_iteration = end_iteration or n_steps
         latents_steps, trace_steps = self.diffusion(
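For reference, a minimal usage sketch of the new get_initial_latents signature (illustrative only, not part of the commit): passing an explicit torch.Generator makes the initial latents reproducible, and n_prompts repeats the same noise for each prompt in the batch. The checkpoint id and the generator device below are assumptions, and get_noise's internal handling of height/width is not shown in this hunk.

    import torch
    from StableDiffuser import StableDiffuser

    # assumed checkpoint; any diffusers-compatible Stable Diffusion repo id should work
    diffuser = StableDiffuser(scheduler='DDIM',
                              repo_id_or_path='CompVis/stable-diffusion-v1-4').to('cuda')
    diffuser.set_scheduler_timesteps(50)

    # a fixed-seed generator gives reproducible initial latents;
    # the generator device may need to match where get_noise allocates its tensor
    generator = torch.Generator(device='cuda').manual_seed(1234)
    latents = diffuser.get_initial_latents(1, 512, 512, n_prompts=2, generator=generator)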
app.py
CHANGED
@@ -10,22 +10,16 @@ from memory_efficiency import MemoryEfficiencyWrapper
 from train import train
 
 import os
-model_map = {'Van Gogh': 'models/vangogh.pt',
-             'Pablo Picasso': 'models/pablopicasso.pt',
-             'Car': 'models/car.pt',
-             'Garbage Truck': 'models/garbagetruck.pt',
-             'French Horn': 'models/frenchhorn.pt',
-             'Kilian Eng': 'models/kilianeng.pt',
-             'Thomas Kinkade': 'models/thomaskinkade.pt',
-             'Tyler Edlin': 'models/tyleredlin.pt',
-             'Kelly McKernan': 'models/kellymckernan.pt',
-             'Rembrandt': 'models/rembrandt.pt' }
-for model_file in os.listdir('models'):
-    path = 'models/' + model_file
-    if any([existing_path == path for existing_path in model_map.values()]):
-        continue
-    model_map[model_file] = path
 
+def populate_model_map():
+    model_map = {}
+    for model_file in os.listdir('models'):
+        path = 'models/' + model_file
+        if any([existing_path == path for existing_path in model_map.values()]):
+            continue
+        model_map[model_file] = path
+    return model_map
+model_map = populate_model_map()
 
 ORIGINAL_SPACE_ID = 'baulab/Erasing-Concepts-In-Diffusion'
 SPACE_ID = os.getenv('SPACE_ID')
@@ -85,6 +79,10 @@ class Demo:
                 value='Van Gogh',
                 interactive=True
             )
+            self.model_reload_button = gr.Button(
+                value="🔄",
+                interactive=True
+            )
 
             self.seed_infr = gr.Number(
                 label="Seed",
@@ -196,6 +194,11 @@ class Demo:
                 label="Seed",
                 info="Set to a fixed number for reproducible training results, or use -1 to pick randomly"
             )
+            self.train_save_every_input = gr.Number(
+                value=-1,
+                label="Save every N steps",
+                info="If >0, save the model throughout training at the given step interval."
+            )
 
             with gr.Column():
                 self.train_memory_options = gr.Markdown(interactive=False,
@@ -215,6 +218,10 @@ class Demo:
                 value="Train",
             )
 
+            self.train_cancel_button = gr.Button(
+                value="Cancel training"
+            )
+
             self.download = gr.Files()
 
             with gr.Tab("Export") as export_column:
@@ -268,7 +275,10 @@ class Demo:
                     self.image_orig
                 ]
             )
-            self.
+            self.model_reload_button.click(self.reload_models,
+                                           inputs=[self.model_dropdown],
+                                           outputs=[self.model_dropdown])
+            train_event = self.train_button.click(self.train, inputs = [
                 self.train_model_input,
                 self.train_img_size_input,
                 self.prompt_input,
@@ -281,9 +291,12 @@ class Demo:
                 self.train_use_amp_input,
                 self.train_use_gradient_checkpointing_input,
                 self.train_seed_input,
+                self.train_save_every_input,
             ],
             outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
             )
+            self.train_cancel_button.click(lambda x: print("cancel pressed"), cancels=[train_event])
+
             self.export_button.click(self.export, inputs = [
                 self.model_dropdown_export,
                 self.base_repo_id_or_path_input_export,
@@ -293,9 +306,15 @@ class Demo:
             outputs=[self.export_status]
             )
 
+    def reload_models(self, model_dropdown):
+        current_model_name = model_dropdown
+        global model_map
+        model_map = populate_model_map()
+        return [gr.Dropdown.update(choices=list(model_map.keys()), value=current_model_name)]
+
     def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
               use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
-              seed=-1,
+              seed=-1, save_every=-1,
               pbar = gr.Progress(track_tqdm=True)):
 
         if self.training:
@@ -331,10 +350,13 @@ class Demo:
 
         try:
             self.training = True
+            self.train_cancel_button.update(interactive=True)
             train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
-
+                  use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing,
+                  seed=int(seed), save_every=int(save_every))
         finally:
             self.training = False
+            self.train_cancel_button.update(interactive=False)
 
         torch.cuda.empty_cache()
 
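The cancel wiring above relies on Gradio's event-cancellation mechanism: Button.click() returns an event handle, and passing that handle in another listener's cancels list aborts the in-flight job. A self-contained sketch of the same pattern (illustrative names, assuming Gradio 3.x with queueing enabled, which cancellation requires):

    import time
    import gradio as gr

    def slow_job(steps, progress=gr.Progress(track_tqdm=True)):
        for _ in progress.tqdm(range(int(steps))):
            time.sleep(1)
        return "done"

    with gr.Blocks() as demo:
        steps = gr.Number(value=30, label="Steps")
        status = gr.Textbox(label="Status")
        run = gr.Button("Train")
        cancel = gr.Button("Cancel training")

        run_event = run.click(slow_job, inputs=[steps], outputs=[status])
        # clicking cancel aborts the in-flight run_event, like train_cancel_button above
        cancel.click(lambda: print("cancel pressed"), cancels=[run_event])

    demo.queue().launch()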
isolate_rng.py
ADDED
@@ -0,0 +1,73 @@
+# copy/pasted from pytorch lightning
+# https://github.com/Lightning-AI/lightning/blob/0d52f4577310b5a1624bed4d23d49e37fb05af9e/src/lightning_fabric/utilities/seed.py
+# and
+# https://github.com/Lightning-AI/lightning/blob/98f7696d1681974d34fad59c03b4b58d9524ed13/src/pytorch_lightning/utilities/seed.py
+
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import contextmanager
+from typing import Generator, Dict, Any
+
+import torch
+import numpy as np
+from random import getstate as python_get_rng_state
+from random import setstate as python_set_rng_state
+
+
+def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
+    """Collect the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python."""
+    states = {
+        "torch": torch.get_rng_state(),
+        "numpy": np.random.get_state(),
+        "python": python_get_rng_state(),
+    }
+    if include_cuda:
+        states["torch.cuda"] = torch.cuda.get_rng_state_all()
+    return states
+
+
+def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
+    """Set the global random state of :mod:`torch`, :mod:`torch.cuda`, :mod:`numpy` and Python in the current
+    process."""
+    torch.set_rng_state(rng_state_dict["torch"])
+    # torch.cuda rng_state is only included since v1.8.
+    if "torch.cuda" in rng_state_dict:
+        torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])
+    np.random.set_state(rng_state_dict["numpy"])
+    version, state, gauss = rng_state_dict["python"]
+    python_set_rng_state((version, tuple(state), gauss))
+
+
+@contextmanager
+def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
+    """A context manager that resets the global random state on exit to what it was before entering.
+    It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
+    Args:
+        include_cuda: Whether to allow this function to also control the `torch.cuda` random number generator.
+            Set this to ``False`` when using the function in a forked process where CUDA re-initialization is
+            prohibited.
+    Example:
+        >>> import torch
+        >>> torch.manual_seed(1)  # doctest: +ELLIPSIS
+        <torch._C.Generator object at ...>
+        >>> with isolate_rng():
+        ...     [torch.rand(1) for _ in range(3)]
+        [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
+        >>> torch.rand(1)
+        tensor([0.7576])
+    """
+    states = _collect_rng_states(include_cuda)
+    yield
+    _set_rng_states(states)
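train.py imports this context manager (see the next diff); the call site is outside the hunks shown, so here is an illustrative sketch of what it guarantees: RNG use inside the block leaves the surrounding torch / numpy / python random streams exactly where they were.

    import torch
    from isolate_rng import isolate_rng

    torch.manual_seed(1234)                  # the training run's RNG stream
    with isolate_rng():
        torch.manual_seed(0)                 # temporary seed, e.g. for a reproducible preview
        preview_noise = torch.randn(1, 4, 64, 64)
    # on exit the torch/numpy/python states are restored, so the next draw is
    # identical to what it would have been without the preview block
    next_training_noise = torch.randn(1, 4, 64, 64)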
train.py
CHANGED
@@ -1,4 +1,4 @@
-
+import random
 
 from accelerate.utils import set_seed
 from torch.cuda.amp import autocast
@@ -8,11 +8,12 @@ from finetuning import FineTunedModel
 import torch
 from tqdm import tqdm
 
+from isolate_rng import isolate_rng
 from memory_efficiency import MemoryEfficiencyWrapper
 
 
 def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
-          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1):
+          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False, seed=-1, save_every=-1):
 
     nsteps = 50
     diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
@@ -54,6 +55,9 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
         seed = random.randint(0, 2 ** 30)
     set_seed(int(seed))
 
+    prev_losses = []
+    start_loss = None
+    max_prev_loss_count = 10
     for i in pbar:
         with torch.no_grad():
             diffuser.set_scheduler_timesteps(nsteps)
@@ -92,6 +96,22 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
         memory_efficiency_wrapper.step(optimizer, loss)
         optimizer.zero_grad()
 
+        # print moving average loss
+        prev_losses.append(loss.detach().clone())
+        if len(prev_losses) > max_prev_loss_count:
+            prev_losses.pop(0)
+        if start_loss is None:
+            start_loss = prev_losses[-1]
+        if len(prev_losses) >= max_prev_loss_count:
+            moving_average_loss = sum(prev_losses) / len(prev_losses)
+            print(
+                f"step {i}: loss={loss.item()} (avg={moving_average_loss.item()}, start ∆={(moving_average_loss - start_loss).item()}")
+        else:
+            print(f"step {i}: loss={loss.item()}")
+
+        if save_every > 0 and ((i % save_every) == (save_every-1)):
+            torch.save(finetuner.state_dict(), save_path + f"__step_{i}.pt")
+
     torch.save(finetuner.state_dict(), save_path)
 
     del diffuser, loss, optimizer, finetuner, negative_latents, neutral_latents, positive_latents, latents_steps, latents
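The logging added above keeps a rolling window of the last 10 losses and, once the window is full, prints the moving average and its drift from the first recorded loss; the periodic checkpoints land next to save_path with a __step_{i}.pt suffix (with 0-based steps and save_every=50 that means steps 49, 99, 149, and so on, plus the final save). A standalone sketch of the same logging pattern (illustrative, not part of the commit):

    from collections import deque

    def make_loss_logger(window=10):
        """Report per-step loss, plus a moving average once `window` losses have been seen."""
        prev_losses = deque(maxlen=window)   # deque drops the oldest entry automatically
        start_loss = None

        def log(step, loss_value):
            nonlocal start_loss
            prev_losses.append(loss_value)
            if start_loss is None:
                start_loss = loss_value
            if len(prev_losses) == window:
                avg = sum(prev_losses) / len(prev_losses)
                print(f"step {step}: loss={loss_value:.4f} (avg={avg:.4f}, start ∆={avg - start_loss:.4f})")
            else:
                print(f"step {step}: loss={loss_value:.4f}")

        return log

    # usage inside a training loop: log = make_loss_logger(); then log(i, loss.item()) each step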