Commit ab11bdd — Damian Stewart
Parent(s): fc73e59
actually use AMP=3x speedup

Files changed:
- StableDiffuser.py +4 -8
- app.py +23 -12
- train.py +13 -12
StableDiffuser.py CHANGED
@@ -4,7 +4,6 @@ import torch
 from baukit import TraceDict
 from diffusers import StableDiffusionPipeline
 from PIL import Image
-from torch.cuda.amp import GradScaler
 from tqdm.auto import tqdm
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
@@ -35,6 +34,7 @@ class StableDiffuser(torch.nn.Module):
 
     def __init__(self,
                  scheduler='LMS',
+                 keep_pipeline=False,
                  repo_id_or_path="CompVis/stable-diffusion-v1-4"):
 
         super().__init__()
@@ -46,6 +46,7 @@ class StableDiffuser(torch.nn.Module):
         self.tokenizer = self.pipeline.tokenizer
         self.text_encoder = self.pipeline.text_encoder
         self.safety_checker = self.pipeline.safety_checker
+        self.feature_extractor = self.pipeline.feature_extractor
 
         if scheduler == 'LMS':
             self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
@@ -55,10 +56,8 @@
             self.scheduler = DDPMScheduler.from_pretrained(repo_id_or_path, subfolder="scheduler")
 
         self.eval()
-
-
-    def feature_extractor(self):
-        return self.pipeline.feature_extractor
+        if not keep_pipeline:
+            del self.pipeline
 
     def get_noise(self, batch_size, width, height, generator=None):
         param = list(self.parameters())[0]
@@ -226,9 +225,6 @@ class StableDiffuser(torch.nn.Module):
 
         return images_steps
 
-    def save_pretrained(self, path, **kwargs):
-        self.pipeline.save_pretrained(path, **kwargs)
-
 
 if __name__ == '__main__':
 
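Note on this change: the `feature_extractor()` method becomes a plain attribute, the `save_pretrained` wrapper is removed, and a new `keep_pipeline` flag controls whether the wrapped StableDiffusionPipeline survives construction. A minimal usage sketch, assumed from this diff alone (the save path is illustrative):

from StableDiffuser import StableDiffuser

# Default: the pipeline is deleted once its submodules are captured,
# which frees memory during training and inference.
diffuser = StableDiffuser(scheduler='DDIM',
                          repo_id_or_path="CompVis/stable-diffusion-v1-4").to('cuda')

# Export: keep the pipeline so pipeline.save_pretrained() remains available.
exporter = StableDiffuser(scheduler='DDIM',
                          keep_pipeline=True,
                          repo_id_or_path="CompVis/stable-diffusion-v1-4").eval()
exporter.pipeline.save_pretrained("/tmp/erased-model")  # illustrative path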
app.py CHANGED
@@ -162,9 +162,9 @@ class Demo:
             info="Prompt corresponding to concept to erase"
         )
 
-        choices = ['ESD-x', 'ESD-self']
-        if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40 or is_xformers_available():
-            choices.append('ESD-u')
+        choices = ['ESD-x', 'ESD-self', 'ESD-u']
+        #if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40 or is_xformers_available():
+        #    choices.append('ESD-u')
 
         self.train_method_input = gr.Dropdown(
             choices=choices,
@@ -274,7 +274,7 @@ class Demo:
             self.train_use_amp_input,
             #self.train_use_gradient_checkpointing_input
         ],
-            outputs=[self.train_button, …]
+            outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
         )
         self.export_button.click(self.export, inputs = [
             self.model_dropdown_export,
@@ -286,12 +286,19 @@ class Demo:
         )
 
     def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
-              use_adamw8bit=True, use_xformers=…
+              use_adamw8bit=True, use_xformers=False, use_amp=False, use_gradient_checkpointing=False,
               pbar = gr.Progress(track_tqdm=True)):
 
         if self.training:
             return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
 
+        print(f"Training {repo_id_or_path} at {img_size} to remove '{prompt}'.")
+        print(f"  {train_method}, negative guidance {neg_guidance}, lr {lr}, {iterations} iterations.")
+        print(f"  {'✅' if use_gradient_checkpointing else '❌'} gradient checkpointing")
+        print(f"  {'✅' if use_amp else '❌'} AMP")
+        print(f"  {'✅' if use_xformers else '❌'} xformers")
+        print(f"  {'✅' if use_adamw8bit else '❌'} 8-bit AdamW")
+
         if train_method == 'ESD-x':
             modules = ".*attn2$"
             frozen = []
@@ -319,20 +326,24 @@ class Demo:
         new_model_name = f'*new* {os.path.basename(save_path)}'
         model_map[new_model_name] = save_path
 
-        return [gr.update(interactive=True, value='Train'),
-                … 'Try your model ({new_model_name}) in the "Test" tab'),
+        return [gr.update(interactive=True, value='Train'),
+                gr.update(value=f'Done Training! Try your model ({new_model_name}) in the "Test" tab'),
+                save_path,
                 gr.Dropdown.update(choices=list(model_map.keys()), value=new_model_name)]
 
     def export(self, model_name, base_repo_id_or_path, save_path, save_half):
         model_path = model_map[model_name]
         checkpoint = torch.load(model_path)
-        …
-        …
+        diffuser = StableDiffuser(scheduler='DDIM',
+                                  keep_pipeline=True,
+                                  repo_id_or_path=base_repo_id_or_path
+                                  ).eval()
+        finetuner = FineTunedModel.from_checkpoint(diffuser, checkpoint).eval()
         with finetuner:
             if save_half:
-                …
-                …
-                …
+                diffuser = diffuser.half()
+                diffuser.pipeline.to(torch.float16, torch_device=diffuser.device)
+            diffuser.pipeline.save_pretrained(save_path)
 
 
     def inference(self, prompt, negative_prompt, seed, width, height, model_name, base_repo_id_or_path, pbar = gr.Progress(track_tqdm=True)):
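The rewritten export() builds a pipeline-keeping diffuser, applies the fine-tuned weights inside `with finetuner:`, and saves while the block is active so the exported pipeline carries those weights. FineTunedModel itself is not part of this diff; a minimal sketch of the context-manager pattern its usage suggests, with all names below hypothetical:

import torch

class ModuleSwapper:
    # Sketch only: temporarily patch fine-tuned submodules into a model,
    # restoring the originals on exit -- the pattern `with finetuner:` implies.
    def __init__(self, model: torch.nn.Module, replacements: dict):
        self.model = model
        self.replacements = replacements   # attribute name -> fine-tuned module
        self._originals = {}

    def __enter__(self):
        for name, module in self.replacements.items():
            self._originals[name] = getattr(self.model, name)
            setattr(self.model, name, module)   # swap in fine-tuned weights
        return self

    def __exit__(self, *exc):
        for name, module in self._originals.items():
            setattr(self.model, name, module)   # restore original weights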
train.py CHANGED
@@ -1,3 +1,5 @@
+from torch.cuda.amp import autocast
+
 from StableDiffuser import StableDiffuser
 from finetuning import FineTunedModel
 import torch
@@ -8,20 +10,17 @@ from memory_efficiency import MemoryEfficiencyWrapper
 
 def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
           use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False):
-
-    nsteps = 50
 
+    nsteps = 50
     diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
 
     memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
                                                         use_gradient_checkpointing=use_gradient_checkpointing )
     with memory_efficiency_wrapper:
-
         diffuser.train()
-
         finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
-
         if use_adamw8bit:
+            print("using AdamW 8Bit optimizer")
             import bitsandbytes as bnb
             optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
                                             lr=lr,
@@ -30,13 +29,13 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
                                             eps=1e-8
                                             )
         else:
+            print("using Adam optimizer")
             optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
         criteria = torch.nn.MSELoss()
 
         pbar = tqdm(range(iterations))
 
         with torch.no_grad():
-
            neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
            positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
 
@@ -56,7 +55,7 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
             iteration = torch.randint(1, nsteps - 1, (1,)).item()
             latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
 
-            with finetuner:
+            with autocast(enabled=use_amp), finetuner:
                 latents_steps, _ = diffuser.diffusion(
                     latents,
                     positive_text_embeddings,
@@ -67,19 +66,21 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
                 )
 
             diffuser.set_scheduler_timesteps(1000)
-
             iteration = int(iteration / nsteps * 1000)
 
-            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
-            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
+            with autocast(enabled=use_amp):
+                positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
+                neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
 
             with finetuner:
-                negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
+                with autocast(enabled=use_amp):
+                    negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
 
             positive_latents.requires_grad = False
             neutral_latents.requires_grad = False
 
-            loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
+            # loss = criteria(e_n, e_0) works the best try 5000 epochs
+            loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents)))
             memory_efficiency_wrapper.step(optimizer, loss)
             optimizer.step()
 
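These train.py changes are what the commit message refers to: the forward passes now actually run under torch.cuda.amp.autocast, while memory_efficiency_wrapper.step(optimizer, loss) is presumably where the GradScaler bookkeeping lives (MemoryEfficiencyWrapper is not shown in this diff). For reference, a minimal sketch of the standard torch.cuda.amp pattern, with a toy model standing in for the diffuser:

import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(64, 64).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criteria = torch.nn.MSELoss()
scaler = GradScaler(enabled=True)

for _ in range(10):
    optimizer.zero_grad()
    x = torch.randn(8, 64, device='cuda')
    with autocast(enabled=True):                 # forward pass in mixed precision
        loss = criteria(model(x), torch.zeros_like(x))
    scaler.scale(loss).backward()                # scale loss to avoid fp16 underflow
    scaler.step(optimizer)                       # unscales grads, then steps
    scaler.update()                              # adapt the loss scale factor

The training objective itself is unchanged in substance: criteria(negative_latents, neutral_latents - negative_guidance*(positive_latents - neutral_latents)) pushes the fine-tuned model's prediction on the erased prompt away from the concept direction, consistent with the ESD-style training methods (ESD-x, ESD-u, ESD-self) selectable in app.py.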