lllyasviel committed
Commit 5928536 · Parent(s): 44d1f2e
Files changed (1): modules/sd.py (+46 -34)
modules/sd.py CHANGED
@@ -1,50 +1,62 @@
-import os
 import random
 import torch
 import numpy as np
 
 from comfy.sd import load_checkpoint_guess_config
 from nodes import VAEDecode, KSamplerAdvanced, EmptyLatentImage, CLIPTextEncode
-from modules.path import modelfile_path
 
 
-xl_base_filename = os.path.join(modelfile_path, 'sd_xl_base_1.0.safetensors')
-xl_refiner_filename = os.path.join(modelfile_path, 'sd_xl_refiner_1.0.safetensors')
-
-xl_base, xl_base_clip, xl_base_vae, xl_base_clipvision = load_checkpoint_guess_config(xl_base_filename)
-del xl_base_clipvision
-
 opCLIPTextEncode = CLIPTextEncode()
 opEmptyLatentImage = EmptyLatentImage()
 opKSamplerAdvanced = KSamplerAdvanced()
 opVAEDecode = VAEDecode()
 
-with torch.no_grad():
-    positive_conditions = opCLIPTextEncode.encode(clip=xl_base_clip, text='a handsome man in forest')[0]
-    negative_conditions = opCLIPTextEncode.encode(clip=xl_base_clip, text='bad, ugly')[0]
-
-    initial_latent_image = opEmptyLatentImage.generate(width=1024, height=1024, batch_size=1)[0]
-
-    samples = opKSamplerAdvanced.sample(
-        add_noise="enable",
-        noise_seed=random.randint(1, 2 ** 64),
-        steps=25,
-        cfg=9,
-        sampler_name="euler",
-        scheduler="normal",
-        start_at_step=0,
-        end_at_step=25,
-        return_with_leftover_noise="enable",
-        model=xl_base,
-        positive=positive_conditions,
-        negative=negative_conditions,
-        latent_image=initial_latent_image,
+
+class StableDiffusionModel:
+    def __init__(self, unet, vae, clip, clip_vision):
+        self.unet = unet
+        self.vae = vae
+        self.clip = clip
+        self.clip_vision = clip_vision
+
+
+@torch.no_grad()
+def load_model(ckpt_filename):
+    unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename)
+    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
+
+
+@torch.no_grad()
+def encode_prompt_condition(clip, prompt):
+    return opCLIPTextEncode.encode(clip=clip, text=prompt)[0]
+
+
+@torch.no_grad()
+def decode_vae(vae, latent_image):
+    return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
+
+
+@torch.no_grad()
+def ksample(model, positive_condition, negative_condition, latent_image, add_noise=True, noise_seed=None, steps=25, cfg=9,
+            sampler_name='euler_ancestral', scheduler='normal', start_at_step=None, end_at_step=None,
+            return_with_leftover_noise=False):
+    return opKSamplerAdvanced.sample(
+        add_noise='enable' if add_noise else 'disable',
+        noise_seed=noise_seed if isinstance(noise_seed, int) else random.randint(1, 2 ** 64),
+        steps=steps,
+        cfg=cfg,
+        sampler_name=sampler_name,
+        scheduler=scheduler,
+        start_at_step=0 if start_at_step is None else start_at_step,
+        end_at_step=steps if end_at_step is None else end_at_step,
+        return_with_leftover_noise='enable' if return_with_leftover_noise else 'disable',
+        model=model,
+        positive=positive_condition,
+        negative=negative_condition,
+        latent_image=latent_image,
     )[0]
 
-    vae_decoded = opVAEDecode.decode(samples=samples, vae=xl_base_vae)[0]
 
-    for image in vae_decoded:
-        i = 255. * image.cpu().numpy()
-        img = np.clip(i, 0, 255).astype(np.uint8)
-        import cv2
-        cv2.imwrite('a.png', img[:, :, ::-1])
+@torch.no_grad()
+def image_to_numpy(x):
+    return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x]
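
For orientation, here is a minimal usage sketch (an illustration assumed from the diff, not part of this commit) showing how the new helpers compose to reproduce what the deleted inline script did. The checkpoint path, prompts, and output filename are placeholders; the old hard-coded sampler settings are passed explicitly because ksample's defaults differ from them ('euler_ancestral' with leftover noise disabled, versus the old 'euler' with leftover noise enabled).

    # Hypothetical driver script, assuming modules/sd.py as added in this commit.
    import cv2

    from modules.sd import (load_model, encode_prompt_condition, ksample,
                            decode_vae, image_to_numpy, opEmptyLatentImage)

    # Placeholder path; the removed code resolved it via modules.path.modelfile_path.
    xl_base = load_model('sd_xl_base_1.0.safetensors')

    positive = encode_prompt_condition(clip=xl_base.clip, prompt='a handsome man in forest')
    negative = encode_prompt_condition(clip=xl_base.clip, prompt='bad, ugly')

    # EmptyLatentImage.generate returns a one-element tuple, hence the [0].
    latent = opEmptyLatentImage.generate(width=1024, height=1024, batch_size=1)[0]

    samples = ksample(model=xl_base.unet,
                      positive_condition=positive,
                      negative_condition=negative,
                      latent_image=latent,
                      steps=25, cfg=9,
                      sampler_name='euler', scheduler='normal',  # old script's values
                      return_with_leftover_noise=True)           # old script enabled this

    decoded = decode_vae(vae=xl_base.vae, latent_image=samples)
    for idx, img in enumerate(image_to_numpy(decoded)):
        cv2.imwrite(f'a_{idx}.png', img[:, :, ::-1])  # RGB -> BGR for OpenCV

Note that the StableDiffusionModel wrapper keeps clip_vision (the old code deleted it immediately after loading), presumably so later code can make use of it.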