lllyasviel commited on
Commit
827a7ed
·
1 Parent(s): dbf3ae5
Files changed (2) hide show
  1. modules/core.py +64 -0
  2. modules/default_pipeline.py +6 -10
modules/core.py CHANGED
@@ -142,6 +142,70 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sa
142
  return out
143
 
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  @torch.no_grad()
146
  def image_to_numpy(x):
147
  return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x]
 
142
  return out
143
 
144
 
145
+ @torch.no_grad()
146
+ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative, latent,
147
+ seed=None, steps=30, refiner_switch_step=20, cfg=9.0, sampler_name='dpmpp_2m_sde',
148
+ scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
149
+ force_full_denoise=False):
150
+ seed = seed if isinstance(seed, int) else random.randint(1, 2 ** 64)
151
+
152
+ device = comfy.model_management.get_torch_device()
153
+ latent_image = latent["samples"]
154
+
155
+ if disable_noise:
156
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
157
+ else:
158
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
159
+ noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
160
+
161
+ noise_mask = None
162
+ if "noise_mask" in latent:
163
+ noise_mask = latent["noise_mask"]
164
+
165
+ previewer = get_previewer(device, model.model.latent_format)
166
+
167
+ pbar = comfy.utils.ProgressBar(steps)
168
+
169
+ def callback(step, x0, x, total_steps):
170
+ if previewer and step % 3 == 0:
171
+ previewer.preview(x0, step, total_steps)
172
+ pbar.update_absolute(step + 1, total_steps, None)
173
+
174
+ sigmas = None
175
+ disable_pbar = False
176
+
177
+ if noise_mask is not None:
178
+ noise_mask = prepare_mask(noise_mask, noise.shape, device)
179
+
180
+ comfy.model_management.load_model_gpu(model)
181
+ real_model = model.model
182
+
183
+ noise = noise.to(device)
184
+ latent_image = latent_image.to(device)
185
+
186
+ positive_copy = broadcast_cond(positive, noise.shape[0], device)
187
+ negative_copy = broadcast_cond(negative, noise.shape[0], device)
188
+
189
+ models = load_additional_models(positive, negative, model.model_dtype())
190
+
191
+ sampler = KSamplerWithRefiner(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
192
+ denoise=denoise, model_options=model.model_options)
193
+
194
+ samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
195
+ start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise,
196
+ denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar,
197
+ seed=seed)
198
+
199
+ samples = samples.cpu()
200
+
201
+ cleanup_additional_models(models)
202
+
203
+ out = latent.copy()
204
+ out["samples"] = samples
205
+
206
+ return out
207
+
208
+
209
  @torch.no_grad()
210
  def image_to_numpy(x):
211
  return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x]
modules/default_pipeline.py CHANGED
@@ -23,20 +23,16 @@ def process(positive_prompt, negative_prompt, width=1024, height=1024, batch_siz
23
 
24
  empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=batch_size)
25
 
26
- sampled_latent = core.ksampler(
27
  model=xl_base.unet,
28
  positive=positive_conditions,
29
  negative=negative_conditions,
 
 
 
 
30
  latent=empty_latent,
31
- steps=30, start_step=0, last_step=20, disable_noise=False, force_full_denoise=False
32
- )
33
-
34
- sampled_latent = core.ksampler(
35
- model=xl_refiner.unet,
36
- positive=positive_conditions_refiner,
37
- negative=negative_conditions_refiner,
38
- latent=sampled_latent,
39
- steps=30, start_step=20, last_step=30, disable_noise=True, force_full_denoise=True
40
  )
41
 
42
  decoded_latent = core.decode_vae(vae=xl_refiner.vae, latent_image=sampled_latent)
 
23
 
24
  empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=batch_size)
25
 
26
+ sampled_latent = core.ksampler_with_refiner(
27
  model=xl_base.unet,
28
  positive=positive_conditions,
29
  negative=negative_conditions,
30
+ refiner=xl_refiner,
31
+ refiner_positive=positive_conditions_refiner,
32
+ refiner_negative=negative_conditions_refiner,
33
+ refiner_switch_step=20,
34
  latent=empty_latent,
35
+ steps=30, start_step=0, last_step=30, disable_noise=False, force_full_denoise=True
 
 
 
 
 
 
 
 
36
  )
37
 
38
  decoded_latent = core.decode_vae(vae=xl_refiner.vae, latent_image=sampled_latent)