Memory saving
train.py CHANGED
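This commit cuts GPU memory use in two places: once the text embeddings are computed, it deletes the now-unused `vae`, `text_encoder`, and `tokenizer` modules, and inside the training loop it drops the intermediate latents as soon as the loss is formed, calling `torch.cuda.empty_cache()` after each cleanup. It also removes the `losses.append(loss.item())` bookkeeping line. Illustrative sketches of these patterns follow the diff.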
```diff
@@ -10,6 +10,9 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
 diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
 diffuser.train()
 
+
+
+
 finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
 
 optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
@@ -22,7 +25,11 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
 neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
 positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
 
-
+del diffuser.vae
+del diffuser.text_encoder
+del diffuser.tokenizer
+
+torch.cuda.empty_cache()
 
 for i in pbar:
 
@@ -61,8 +68,11 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa
 neutral_latents.requires_grad = False
 
 loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
+
+del negative_latents, neutral_latents, positive_latents, latents_steps, latents
+torch.cuda.empty_cache()
+
 loss.backward()
-losses.append(loss.item())
 optimizer.step()
 
 torch.save(finetuner.state_dict(), save_path)
```
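Once the neutral and positive embeddings are cached, the text stack and the VAE are never used again by this training loop, so deleting them releases their weights. A minimal sketch of the pattern, assuming a `StableDiffuser`-like object that exposes `get_text_embeddings`, `vae`, `text_encoder`, and `tokenizer` as in the diff (the helper function itself is hypothetical):

```python
import torch

def cache_embeddings_and_free(diffuser, prompt):
    # Compute the two conditioning embeddings once, as train.py does.
    neutral = diffuser.get_text_embeddings([''], n_imgs=1)
    positive = diffuser.get_text_embeddings([prompt], n_imgs=1)

    # `del` removes the module attributes; the underlying weights are freed
    # once no other reference (optimizer state, closures, ...) holds them.
    del diffuser.vae
    del diffuser.text_encoder
    del diffuser.tokenizer

    # Return blocks held by PyTorch's caching allocator to the driver so the
    # freed memory is actually available (and visible in nvidia-smi).
    torch.cuda.empty_cache()
    return neutral, positive
```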
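Inside the loop, note that `del` only unbinds the Python names: tensors still referenced by the autograd graph of `loss` (an MSE-style criterion saves both its inputs for the backward pass) are released only when `backward()` consumes the graph, so the immediate savings come from tensors outside the graph such as `latents_steps` and `latents`. A quick way to check what each step actually frees, using PyTorch's standard memory counters:

```python
import torch

def report(tag: str) -> None:
    # memory_allocated(): bytes held by live tensors.
    # memory_reserved(): bytes the caching allocator keeps from the driver;
    # this is the number torch.cuda.empty_cache() shrinks.
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f'{tag}: allocated={alloc:.0f} MiB, reserved={reserved:.0f} MiB')
```

Calling `report()` before and after each `del` / `empty_cache()` pair in `train.py` shows which deletions pay off.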
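The loss line itself implements the negative-guidance objective: the fine-tuned model's prediction for the target prompt is trained toward the frozen neutral prediction pushed away from the concept direction, i.e. target = e_neutral - guidance * (e_positive - e_neutral). A self-contained sketch of that target construction, with `torch.nn.functional.mse_loss` standing in for the repo's `criteria` (an assumption) and illustrative tensor shapes:

```python
import torch
import torch.nn.functional as F

def negative_guidance_target(e_neutral: torch.Tensor,
                             e_positive: torch.Tensor,
                             guidance: float) -> torch.Tensor:
    """Neutral noise estimate pushed away from the concept direction by
    `guidance` (the diff's negative_guidance)."""
    return e_neutral - guidance * (e_positive - e_neutral)

# Illustrative stand-ins for the UNet noise predictions in train.py.
e_neutral, e_positive, e_negative = (torch.randn(2, 4, 64, 64) for _ in range(3))
loss = F.mse_loss(e_negative, negative_guidance_target(e_neutral, e_positive, 1.0))
```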