lllyasviel commited on
Commit
cd7cecf
·
1 Parent(s): 5afc367
Files changed (1) hide show
  1. modules/samplers_advanced.py +59 -38
modules/samplers_advanced.py CHANGED
@@ -23,8 +23,8 @@ class KSamplerAdvanced:
23
  sampler = self.SAMPLERS[0]
24
  self.scheduler = scheduler
25
  self.sampler = sampler
26
- self.sigma_min=float(self.model_wrap.sigma_min)
27
- self.sigma_max=float(self.model_wrap.sigma_max)
28
  self.set_steps(steps, denoise)
29
  self.denoise = denoise
30
  self.model_options = model_options
@@ -40,7 +40,8 @@ class KSamplerAdvanced:
40
  if self.scheduler == "karras":
41
  sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
42
  elif self.scheduler == "exponential":
43
- sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
 
44
  elif self.scheduler == "normal":
45
  sigmas = self.model_wrap.get_sigmas(steps)
46
  elif self.scheduler == "simple":
@@ -59,11 +60,12 @@ class KSamplerAdvanced:
59
  if denoise is None or denoise > 0.9999:
60
  self.sigmas = self.calculate_sigmas(steps).to(self.device)
61
  else:
62
- new_steps = int(steps/denoise)
63
  sigmas = self.calculate_sigmas(new_steps).to(self.device)
64
  self.sigmas = sigmas[-(steps + 1):]
65
 
66
- def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
 
67
  if sigmas is None:
68
  sigmas = self.sigmas
69
  sigma_min = self.sigma_min
@@ -92,7 +94,7 @@ class KSamplerAdvanced:
92
  calculate_start_end_timesteps(self.model_wrap, negative)
93
  calculate_start_end_timesteps(self.model_wrap, positive)
94
 
95
- #make sure each cond area has an opposite one with the same area
96
  for c in positive:
97
  create_cond_with_same_area_if_none(negative, c)
98
  for c in negative:
@@ -100,30 +102,36 @@ class KSamplerAdvanced:
100
 
101
  pre_run_control(self.model_wrap, negative + positive)
102
 
103
- apply_empty_x_to_equal_area(list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
 
 
104
  apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
105
 
106
  if self.model.is_adm():
107
- positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "positive")
108
- negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative")
 
 
109
 
110
  if latent_image is not None:
111
  latent_image = self.model.process_latent_in(latent_image)
112
 
113
- extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options, "seed":seed}
 
114
 
115
  cond_concat = None
116
- if hasattr(self.model, 'concat_keys'): #inpaint
117
  cond_concat = []
118
  for ck in self.model.concat_keys:
119
  if denoise_mask is not None:
120
  if ck == "mask":
121
- cond_concat.append(denoise_mask[:,:1])
122
  elif ck == "masked_image":
123
- cond_concat.append(latent_image) #NOTE: the latent_image should be masked by the mask in pixel space
 
124
  else:
125
  if ck == "mask":
126
- cond_concat.append(torch.ones_like(noise)[:,:1])
127
  elif ck == "masked_image":
128
  cond_concat.append(blank_inpaint_image_like(noise))
129
  extra_args["cond_concat"] = cond_concat
@@ -133,11 +141,16 @@ class KSamplerAdvanced:
133
  else:
134
  max_denoise = True
135
 
136
-
137
  if self.sampler == "uni_pc":
138
- samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar)
 
 
 
139
  elif self.sampler == "uni_pc_bh2":
140
- samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=max_denoise, extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
 
 
 
141
  elif self.sampler == "ddim":
142
  timesteps = []
143
  for s in range(sigmas.shape[0]):
@@ -153,24 +166,26 @@ class KSamplerAdvanced:
153
 
154
  sampler = DDIMSampler(self.model, device=self.device)
155
  sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
156
- z_enc = sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device), noise=noise, max_denoise=max_denoise)
 
 
157
  samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
158
- conditioning=positive,
159
- batch_size=noise.shape[0],
160
- shape=noise.shape[1:],
161
- verbose=False,
162
- unconditional_guidance_scale=cfg,
163
- unconditional_conditioning=negative,
164
- eta=0.0,
165
- x_T=z_enc,
166
- x0=latent_image,
167
- img_callback=ddim_callback,
168
- denoise_function=self.model_wrap.predict_eps_discrete_timestep,
169
- extra_args=extra_args,
170
- mask=noise_mask,
171
- to_zero=sigmas[-1]==0,
172
- end_step=sigmas.shape[0] - 1,
173
- disable_pbar=disable_pbar)
174
 
175
  else:
176
  extra_args["denoise_mask"] = denoise_mask
@@ -190,11 +205,17 @@ class KSamplerAdvanced:
190
  if latent_image is not None:
191
  noise += latent_image
192
  if self.sampler == "dpm_fast":
193
- samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
 
 
194
  elif self.sampler == "dpm_adaptive":
195
- samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
 
 
196
  else:
197
- samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
 
 
 
198
 
199
  return self.model.process_latent_out(samples.to(torch.float32))
200
-
 
23
  sampler = self.SAMPLERS[0]
24
  self.scheduler = scheduler
25
  self.sampler = sampler
26
+ self.sigma_min = float(self.model_wrap.sigma_min)
27
+ self.sigma_max = float(self.model_wrap.sigma_max)
28
  self.set_steps(steps, denoise)
29
  self.denoise = denoise
30
  self.model_options = model_options
 
40
  if self.scheduler == "karras":
41
  sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
42
  elif self.scheduler == "exponential":
43
+ sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=self.sigma_min,
44
+ sigma_max=self.sigma_max)
45
  elif self.scheduler == "normal":
46
  sigmas = self.model_wrap.get_sigmas(steps)
47
  elif self.scheduler == "simple":
 
60
  if denoise is None or denoise > 0.9999:
61
  self.sigmas = self.calculate_sigmas(steps).to(self.device)
62
  else:
63
+ new_steps = int(steps / denoise)
64
  sigmas = self.calculate_sigmas(new_steps).to(self.device)
65
  self.sigmas = sigmas[-(steps + 1):]
66
 
67
+ def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None,
68
+ force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
69
  if sigmas is None:
70
  sigmas = self.sigmas
71
  sigma_min = self.sigma_min
 
94
  calculate_start_end_timesteps(self.model_wrap, negative)
95
  calculate_start_end_timesteps(self.model_wrap, positive)
96
 
97
+ # make sure each cond area has an opposite one with the same area
98
  for c in positive:
99
  create_cond_with_same_area_if_none(negative, c)
100
  for c in negative:
 
102
 
103
  pre_run_control(self.model_wrap, negative + positive)
104
 
105
+ apply_empty_x_to_equal_area(
106
+ list(filter(lambda c: c[1].get('control_apply_to_uncond', False) == True, positive)), negative, 'control',
107
+ lambda cond_cnets, x: cond_cnets[x])
108
  apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
109
 
110
  if self.model.is_adm():
111
+ positive = encode_adm(self.model, positive, noise.shape[0], noise.shape[3], noise.shape[2], self.device,
112
+ "positive")
113
+ negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device,
114
+ "negative")
115
 
116
  if latent_image is not None:
117
  latent_image = self.model.process_latent_in(latent_image)
118
 
119
+ extra_args = {"cond": positive, "uncond": negative, "cond_scale": cfg, "model_options": self.model_options,
120
+ "seed": seed}
121
 
122
  cond_concat = None
123
+ if hasattr(self.model, 'concat_keys'): # inpaint
124
  cond_concat = []
125
  for ck in self.model.concat_keys:
126
  if denoise_mask is not None:
127
  if ck == "mask":
128
+ cond_concat.append(denoise_mask[:, :1])
129
  elif ck == "masked_image":
130
+ cond_concat.append(
131
+ latent_image) # NOTE: the latent_image should be masked by the mask in pixel space
132
  else:
133
  if ck == "mask":
134
+ cond_concat.append(torch.ones_like(noise)[:, :1])
135
  elif ck == "masked_image":
136
  cond_concat.append(blank_inpaint_image_like(noise))
137
  extra_args["cond_concat"] = cond_concat
 
141
  else:
142
  max_denoise = True
143
 
 
144
  if self.sampler == "uni_pc":
145
+ samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas,
146
+ sampling_function=sampling_function, max_denoise=max_denoise,
147
+ extra_args=extra_args, noise_mask=denoise_mask, callback=callback,
148
+ disable=disable_pbar)
149
  elif self.sampler == "uni_pc_bh2":
150
+ samples = uni_pc.sample_unipc(self.model_wrap, noise, latent_image, sigmas,
151
+ sampling_function=sampling_function, max_denoise=max_denoise,
152
+ extra_args=extra_args, noise_mask=denoise_mask, callback=callback,
153
+ variant='bh2', disable=disable_pbar)
154
  elif self.sampler == "ddim":
155
  timesteps = []
156
  for s in range(sigmas.shape[0]):
 
166
 
167
  sampler = DDIMSampler(self.model, device=self.device)
168
  sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
169
+ z_enc = sampler.stochastic_encode(latent_image,
170
+ torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device),
171
+ noise=noise, max_denoise=max_denoise)
172
  samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
173
+ conditioning=positive,
174
+ batch_size=noise.shape[0],
175
+ shape=noise.shape[1:],
176
+ verbose=False,
177
+ unconditional_guidance_scale=cfg,
178
+ unconditional_conditioning=negative,
179
+ eta=0.0,
180
+ x_T=z_enc,
181
+ x0=latent_image,
182
+ img_callback=ddim_callback,
183
+ denoise_function=self.model_wrap.predict_eps_discrete_timestep,
184
+ extra_args=extra_args,
185
+ mask=noise_mask,
186
+ to_zero=sigmas[-1] == 0,
187
+ end_step=sigmas.shape[0] - 1,
188
+ disable_pbar=disable_pbar)
189
 
190
  else:
191
  extra_args["denoise_mask"] = denoise_mask
 
205
  if latent_image is not None:
206
  noise += latent_image
207
  if self.sampler == "dpm_fast":
208
+ samples = k_diffusion_sampling.sample_dpm_fast(self.model_k, noise, sigma_min, sigmas[0], total_steps,
209
+ extra_args=extra_args, callback=k_callback,
210
+ disable=disable_pbar)
211
  elif self.sampler == "dpm_adaptive":
212
+ samples = k_diffusion_sampling.sample_dpm_adaptive(self.model_k, noise, sigma_min, sigmas[0],
213
+ extra_args=extra_args, callback=k_callback,
214
+ disable=disable_pbar)
215
  else:
216
+ samples = getattr(k_diffusion_sampling, "sample_{}".format(self.sampler))(self.model_k, noise, sigmas,
217
+ extra_args=extra_args,
218
+ callback=k_callback,
219
+ disable=disable_pbar)
220
 
221
  return self.model.process_latent_out(samples.to(torch.float32))