Damian Stewart committed
Commit fc73e59
Parent: d8ffb68

wip adding AMP and xformers to training code path

Files changed (6):
  1. StableDiffuser.py +8 -8
  2. app.py +32 -16
  3. finetuning.py +0 -11
  4. memory_efficiency.py +86 -0
  5. requirements.txt +5 -2
  6. train.py +63 -46
StableDiffuser.py CHANGED
@@ -4,6 +4,7 @@ import torch
 from baukit import TraceDict
 from diffusers import StableDiffusionPipeline
 from PIL import Image
+from torch.cuda.amp import GradScaler
 from tqdm.auto import tqdm
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
@@ -34,18 +35,17 @@ class StableDiffuser(torch.nn.Module):
 
     def __init__(self,
                  scheduler='LMS',
-                 repo_id_or_path="CompVis/stable-diffusion-v1-4",
-                 variant='fp16'
-                 ):
+                 repo_id_or_path="CompVis/stable-diffusion-v1-4"):
 
         super().__init__()
 
-        self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path, variant=variant)
+        self.pipeline = StableDiffusionPipeline.from_pretrained(repo_id_or_path)
 
         self.vae = self.pipeline.vae
         self.unet = self.pipeline.unet
         self.tokenizer = self.pipeline.tokenizer
         self.text_encoder = self.pipeline.text_encoder
+        self.safety_checker = self.pipeline.safety_checker
 
         if scheduler == 'LMS':
             self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
@@ -57,8 +57,8 @@ class StableDiffuser(torch.nn.Module):
         self.eval()
 
     @property
-    def safety_checker(self):
-        return self.pipeline.safety_checker
+    def feature_extractor(self):
+        return self.pipeline.feature_extractor
 
     def get_noise(self, batch_size, width, height, generator=None):
         param = list(self.parameters())[0]
@@ -215,9 +215,9 @@ class StableDiffuser(torch.nn.Module):
             self.safety_checker = self.safety_checker.float()
             safety_checker_input = self.feature_extractor(images_steps[i], return_tensors="pt").to(latents_steps[0].device)
             image, has_nsfw_concept = self.safety_checker(
-                images=latents_steps[i].float().cpu().numpy(), clip_input=safety_checker_input.pixel_values.float()
+                images=latents_steps[i], clip_input=safety_checker_input.pixel_values.float()
            )
-            images_steps[i][0] = self.to_image(torch.from_numpy(image))[0]
+            images_steps[i][0] = self.to_image(image)[0]
 
         images_steps = list(zip(*images_steps))

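Note: safety_checker changes from a read-only @property to a plain instance attribute because later code reassigns it (self.safety_checker = self.safety_checker.float() at line 215), and assigning through a property that has no setter raises AttributeError. A minimal illustration with a hypothetical class, not code from this repo:

    class Example:
        @property
        def checker(self):  # read-only: no setter defined
            return self._checker

    ex = Example()
    ex._checker = object()
    # ex.checker = ex._checker  # would raise AttributeError: can't set attribute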
app.py CHANGED
@@ -1,8 +1,12 @@
 import gradio as gr
 import torch
 import os
+
+from diffusers.utils import is_xformers_available
+
 from finetuning import FineTunedModel
 from StableDiffuser import StableDiffuser
+from memory_efficiency import MemoryEfficiencyWrapper
 from train import train
 
 import os
@@ -158,8 +162,8 @@ class Demo:
                     info="Prompt corresponding to concept to erase"
                 )
 
-                choices = ['ESD-x']
-                if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40:
+                choices = ['ESD-x', 'ESD-self']
+                if torch.cuda.get_device_properties(0).total_memory * 1e-9 >= 40 or is_xformers_available():
                     choices.append('ESD-u')
 
                 self.train_method_input = gr.Dropdown(
@@ -188,6 +192,12 @@
                     info='Learning rate used to train'
                 )
 
+                with gr.Row():
+                    self.train_use_adamw8bit_input = gr.Checkbox(label="8bit AdamW", value=False)
+                    self.train_use_xformers_input = gr.Checkbox(label="xformers", value=True)
+                    self.train_use_amp_input = gr.Checkbox(label="AMP", value=True)
+                    # self.train_use_gradient_checkpointing_input = gr.Checkbox(label="Gradient checkpointing", value=True)
+
             with gr.Column(scale=1):
 
                 self.train_status = gr.Button(value='', variant='primary', label='Status', interactive=False)
@@ -258,7 +268,11 @@
                 self.train_method_input,
                 self.neg_guidance_input,
                 self.iterations_input,
-                self.lr_input
+                self.lr_input,
+                self.train_use_adamw8bit_input,
+                self.train_use_xformers_input,
+                self.train_use_amp_input,
+                # self.train_use_gradient_checkpointing_input
             ],
             outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
         )
@@ -271,41 +285,43 @@
             outputs=[self.export_button]
         )
 
-    def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr, pbar = gr.Progress(track_tqdm=True)):
+    def train(self, repo_id_or_path, img_size, prompt, train_method, neg_guidance, iterations, lr,
+              use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=True,
+              pbar=gr.Progress(track_tqdm=True)):
 
         if self.training:
             return [gr.update(interactive=True, value='Train'), gr.update(value='Someone else is training... Try again soon'), None, gr.update()]
 
         if train_method == 'ESD-x':
-
             modules = ".*attn2$"
             frozen = []
 
         elif train_method == 'ESD-u':
-
             modules = "unet$"
             frozen = [".*attn2$", "unet.time_embedding$", "unet.conv_out$"]
 
         elif train_method == 'ESD-self':
-
             modules = ".*attn1$"
             frozen = []
 
         randn = torch.randint(1, 10000000, (1,)).item()
 
-        save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}.pt"
-
-        self.training = True
-
-        train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path)
-
-        self.training = False
+        save_path = f"models/{randn}_{prompt.lower().replace(' ', '')}_{train_method}_ng{neg_guidance}_lr{lr}_iter{iterations}.pt"
+        try:
+            self.training = True
+            train(repo_id_or_path, img_size, prompt, modules, frozen, iterations, neg_guidance, lr, save_path,
+                  use_adamw8bit, use_xformers, use_amp, use_gradient_checkpointing)
+        finally:
+            self.training = False
 
         torch.cuda.empty_cache()
 
-        model_map['Custom'] = save_path
+        new_model_name = f'*new* {os.path.basename(save_path)}'
+        model_map[new_model_name] = save_path
 
-        return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom model in the "Test" tab'), save_path, gr.Dropdown.update(choices=list(model_map.keys()), value='Custom')]
+        return [gr.update(interactive=True, value='Train'), gr.update(value=f'Done Training! \n '
+                f'Try your model ({new_model_name}) in the "Test" tab'), save_path,
+                gr.Dropdown.update(choices=list(model_map.keys()), value=new_model_name)]
 
     def export(self, model_name, base_repo_id_or_path, save_path, save_half):
         model_path = model_map[model_name]
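
Note: Gradio passes the inputs list to an event handler positionally, so the three new checkbox components must stay in the same order as the use_adamw8bit, use_xformers and use_amp parameters of Demo.train; the commented-out gradient-checkpointing entry is skipped in both lists, which keeps the remaining arguments aligned.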
finetuning.py CHANGED
@@ -51,7 +51,6 @@ class FineTunedModel(torch.nn.Module):
 
     @classmethod
     def from_checkpoint(cls, model, checkpoint, frozen_modules=[]):
-
         if isinstance(checkpoint, str):
             checkpoint = torch.load(checkpoint)
 
@@ -64,33 +63,23 @@
 
 
     def __enter__(self):
-
         for key, ft_module in self.ft_modules.items():
             util.set_module(self.model, key, ft_module)
 
     def __exit__(self, exc_type, exc_value, tb):
-
         for key, module in self.orig_modules.items():
             util.set_module(self.model, key, module)
 
     def parameters(self):
-
         parameters = []
-
         for ft_module in self.ft_modules.values():
-
             parameters.extend(list(ft_module.parameters()))
-
         return parameters
 
     def state_dict(self):
-
         state_dict = {key: module.state_dict() for key, module in self.ft_modules.items()}
-
         return state_dict
 
     def load_state_dict(self, state_dict):
-
         for key, sd in state_dict.items():
-
             self.ft_modules[key].load_state_dict(sd)
memory_efficiency.py ADDED
@@ -0,0 +1,86 @@
+# adapted from EveryDream2Trainer
+import contextlib
+import traceback
+
+import torch
+from torch.cuda.amp import GradScaler
+
+from StableDiffuser import StableDiffuser
+
+
+class MemoryEfficiencyWrapper:
+
+    def __init__(self,
+                 diffuser: StableDiffuser,
+                 use_amp: bool,
+                 use_xformers: bool,
+                 use_gradient_checkpointing: bool):
+        self.diffuser = diffuser
+        self.is_sd1attn = diffuser.unet.config["attention_head_dim"] == [8, 8, 8, 8]
+        self.is_sd1attn = diffuser.unet.config["attention_head_dim"] == 8 or self.is_sd1attn
+
+        self.use_amp = use_amp
+        self.use_xformers = use_xformers
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+
+    def __enter__(self):
+        if self.use_gradient_checkpointing:
+            self.diffuser.unet.enable_gradient_checkpointing()
+            self.diffuser.text_encoder.gradient_checkpointing_enable()
+
+        if self.use_xformers:
+            if (self.use_amp and self.is_sd1attn) or (not self.is_sd1attn):
+                try:
+                    self.diffuser.unet.enable_xformers_memory_efficient_attention()
+                    print("Enabled xformers")
+                except Exception as ex:
+                    print("failed to load xformers, using attention slicing instead")
+                    self.diffuser.unet.set_attention_slice("auto")
+            elif not self.use_amp and self.is_sd1attn:
+                print("AMP is disabled but model is SD1.X, using attention slicing instead of xformers")
+                self.diffuser.unet.set_attention_slice("auto")
+        else:
+            print("xformers disabled via arg, using attention slicing instead")
+            self.diffuser.unet.set_attention_slice("auto")
+
+        self.diffuser.vae = self.diffuser.vae.to(self.diffuser.vae.device, dtype=torch.float16 if self.use_amp else torch.float32)
+        self.diffuser.unet = self.diffuser.unet.to(self.diffuser.unet.device, dtype=torch.float32)
+
+        try:
+            # unet = torch.compile(unet)
+            # text_encoder = torch.compile(text_encoder)
+            # vae = torch.compile(vae)
+            torch.set_float32_matmul_precision('high')
+            torch.backends.cudnn.allow_tf32 = True
+            # logging.info("Successfully compiled models")
+        except Exception as ex:
+            print(f"Failed to compile model, continuing anyway, ex: {ex}")
+
+        self.grad_scaler = GradScaler(
+            enabled=self.use_amp,
+            init_scale=2 ** 17.5,
+            growth_factor=2,
+            backoff_factor=1.0 / 2,
+            growth_interval=25,
+        )
+
+    def step(self, optimizer, loss):
+        self.grad_scaler.scale(loss).backward()
+        self.grad_scaler.step(optimizer)
+        self.grad_scaler.update()
+
+    def __exit__(self, exc_type, exc_value, tb):
+        if exc_type is not None:
+            traceback.print_exception(exc_type, exc_value, tb)
+            # return False  # uncomment to pass the exception through
+        self.diffuser.unet.disable_gradient_checkpointing()
+        try:
+            self.diffuser.text_encoder.gradient_checkpointing_disable()
+        except AttributeError:
+            # self.diffuser.text_encoder is likely `del`eted
+            pass
+
+        self.diffuser.unet.disable_xformers_memory_efficient_attention()
+        self.diffuser.unet.set_attention_slice("auto")
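
Note: step() above is the standard torch.cuda.amp GradScaler recipe: scale the loss, run backward, step the optimizer through the scaler (which unscales the gradients and skips the step if they contain inf/NaN), then update the loss scale. The canonical pattern also wraps the forward pass in torch.autocast, which this wip commit does not yet do in the training loop. A minimal sketch of the full pattern, using generic model/optimizer/loss_fn placeholders rather than code from this repo:

    import torch
    from torch.cuda.amp import GradScaler

    scaler = GradScaler(enabled=True)

    def training_step(model, optimizer, loss_fn, batch, target):
        optimizer.zero_grad()
        # the forward pass runs in float16 where it is safe to do so
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            loss = loss_fn(model(batch), target)
        scaler.scale(loss).backward()  # backprop through the scaled loss
        scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
        scaler.update()                # adjusts the loss scale for the next iteration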
requirements.txt CHANGED
@@ -1,8 +1,11 @@
 gradio
-torch==1.13.1 --index-url https://download.pytorch.org/whl/cu118
-torchvision==0.14.1 --index-url https://download.pytorch.org/whl/cu118
+torch --index-url https://download.pytorch.org/whl/cu118
+torchvision --index-url https://download.pytorch.org/whl/cu118
 diffusers
 transformers
 accelerate
 scipy
 git+https://github.com/davidbau/baukit.git
+xformers
+bitsandbytes==0.38.1
+safetensors
train.py CHANGED
@@ -3,68 +3,85 @@ from finetuning import FineTunedModel
 import torch
 from tqdm import tqdm
 
-def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path):
+from memory_efficiency import MemoryEfficiencyWrapper
+
+
+def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations, negative_guidance, lr, save_path,
+          use_adamw8bit=True, use_xformers=True, use_amp=True, use_gradient_checkpointing=False):
 
     nsteps = 50
 
     diffuser = StableDiffuser(scheduler='DDIM', repo_id_or_path=repo_id_or_path).to('cuda')
-    diffuser.train()
 
-    finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
-
-    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
-    criteria = torch.nn.MSELoss()
-
-    pbar = tqdm(range(iterations))
-
-    with torch.no_grad():
-
-        neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
-        positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
-
-    del diffuser.vae
-    del diffuser.text_encoder
-    del diffuser.tokenizer
-
-    torch.cuda.empty_cache()
-
-    print(f"using img_size of {img_size}")
-
-    for i in pbar:
-        with torch.no_grad():
-            diffuser.set_scheduler_timesteps(nsteps)
-            optimizer.zero_grad()
-
-            iteration = torch.randint(1, nsteps - 1, (1,)).item()
-            latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
-
-            with finetuner:
-                latents_steps, _ = diffuser.diffusion(
-                    latents,
-                    positive_text_embeddings,
-                    start_iteration=0,
-                    end_iteration=iteration,
-                    guidance_scale=3,
-                    show_progress=False
-                )
-
-            diffuser.set_scheduler_timesteps(1000)
-
-            iteration = int(iteration / nsteps * 1000)
-
-            positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
-            neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
-
-        with finetuner:
-            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
-
-        positive_latents.requires_grad = False
-        neutral_latents.requires_grad = False
-
-        loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
-
-        loss.backward()
-        optimizer.step()
+    memory_efficiency_wrapper = MemoryEfficiencyWrapper(diffuser=diffuser, use_amp=use_amp, use_xformers=use_xformers,
+                                                        use_gradient_checkpointing=use_gradient_checkpointing)
+    with memory_efficiency_wrapper:
+
+        diffuser.train()
+
+        finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)
+
+        if use_adamw8bit:
+            import bitsandbytes as bnb
+            optimizer = bnb.optim.AdamW8bit(finetuner.parameters(),
+                                            lr=lr,
+                                            betas=(0.9, 0.999),
+                                            weight_decay=0.010,
+                                            eps=1e-8
+                                            )
+        else:
+            optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
+        criteria = torch.nn.MSELoss()
+
+        pbar = tqdm(range(iterations))
+
+        with torch.no_grad():
+
+            neutral_text_embeddings = diffuser.get_text_embeddings([''],n_imgs=1)
+            positive_text_embeddings = diffuser.get_text_embeddings([prompt],n_imgs=1)
+
+        del diffuser.vae
+        del diffuser.text_encoder
+        del diffuser.tokenizer
+
+        torch.cuda.empty_cache()
+
+        print(f"using img_size of {img_size}")
+
+        for i in pbar:
+            with torch.no_grad():
+                diffuser.set_scheduler_timesteps(nsteps)
+                optimizer.zero_grad()
+
+                iteration = torch.randint(1, nsteps - 1, (1,)).item()
+                latents = diffuser.get_initial_latents(1, width=img_size, height=img_size, n_prompts=1)
+
+                with finetuner:
+                    latents_steps, _ = diffuser.diffusion(
+                        latents,
+                        positive_text_embeddings,
+                        start_iteration=0,
+                        end_iteration=iteration,
+                        guidance_scale=3,
+                        show_progress=False
+                    )
+
+                diffuser.set_scheduler_timesteps(1000)
+
+                iteration = int(iteration / nsteps * 1000)
+
+                positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
+                neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)
+
+            with finetuner:
+                negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
+
+            positive_latents.requires_grad = False
+            neutral_latents.requires_grad = False
+
+            loss = criteria(negative_latents, neutral_latents - (negative_guidance*(positive_latents - neutral_latents))) #loss = criteria(e_n, e_0) works the best try 5000 epochs
+            memory_efficiency_wrapper.step(optimizer, loss)
 
     torch.save(finetuner.state_dict(), save_path)
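
Note: memory_efficiency_wrapper.step() already calls grad_scaler.step(optimizer), so no separate optimizer.step() should follow it; calling both would apply the update twice. For illustration, a hypothetical call into the updated train() signature with the new memory-efficiency flags spelled out (the argument values are examples, not defaults from this repo):

    from train import train

    train("CompVis/stable-diffusion-v1-4",   # repo_id_or_path
          512,                               # img_size
          "Van Gogh",                        # prompt
          ".*attn2$",                        # modules (ESD-x: cross-attention only)
          [],                                # freeze_modules
          1000,                              # iterations
          1.0,                               # negative_guidance
          1e-5,                              # lr
          "models/example_esd-x.pt",         # save_path
          use_adamw8bit=False,
          use_xformers=True,
          use_amp=True,
          use_gradient_checkpointing=False)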