"""Minimal SDXL (base 1.0) text-to-image sampling script built on the sgm library.

Submodules are moved to the GPU only while they are needed, so the pipeline
can run with limited VRAM.
"""
import math

import numpy as np
import torch
import safetensors.torch

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler


def get_unique_embedder_keys_from_conditioner(conditioner):
    # Collect the distinct conditioning keys (e.g. "txt", "original_size_as_tuple")
    # expected by the model's embedders.
    return list({x.input_key for x in conditioner.embedders})


def get_batch(keys, value_dict, N, device="cuda"):
    """Build the conditional (`batch`) and unconditional (`batch_uc`) conditioning
    dicts for batch shape `N` from the demo values in `value_dict`."""
    # Hardcoded demo setups; might undergo some changes in the future

    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = (
                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
            batch_uc["txt"] = (
                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
                .reshape(N)
                .tolist()
            )
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor(
                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
                )
                .to(device)
                .repeat(*N, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = (
                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
            )
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]])
                .to(device)
                .repeat(*N, 1)
            )

        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(*N, 1)
            )
        else:
            batch[key] = value_dict[key]

    for key in batch.keys():
        if key not in batch_uc and isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
    return batch, batch_uc


# Euler-ancestral sampler: 40 steps with vanilla classifier-free guidance (CFG scale 9.0).
sampler = EulerAncestralSampler(
    num_steps=40,
    discretization_config={
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    },
    guider_config={
        "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
        "params": {"scale": 9.0, "dyn_thresh_config": {
            "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
        }},
    },
    eta=1.0,
    s_noise=1.0,
    verbose=True,
)

# Inference only: disable autograd globally so sampling does not build a graph.
torch.set_grad_enabled(False)

# Instantiate the model on the CPU in float16, then load the SDXL base weights.
config_path = './sd_xl_base.yaml'
config = OmegaConf.load(config_path)
model = instantiate_from_config(config.model).cpu()
model.to(torch.float16)
model.eval()
model.load_state_dict(safetensors.torch.load_file('./sd_xl_base_1.0.safetensors'), strict=False)

# Stage 1: text conditioning. Move only the conditioner to the GPU.
model.conditioner.cuda()

value_dict = {
    "prompt": "a handsome man in forest",
    "negative_prompt": "ugly, bad",
    "orig_height": 1024,
    "orig_width": 1024,
    "crop_coords_top": 0,
    "crop_coords_left": 0,
    "target_height": 1024,
    "target_width": 1024,
    "aesthetic_score": 7.5,
    "negative_aesthetic_score": 2.0,
}

batch, batch_uc = get_batch(
    get_unique_embedder_keys_from_conditioner(model.conditioner),
    value_dict,
    [1],  # batch shape: a single sample
)

# Compute conditional (c) and unconditional (uc) embeddings, then return the
# conditioner to the CPU and release its GPU memory.
c, uc = model.conditioner.get_unconditional_conditioning(
    batch,
    batch_uc=batch_uc,
)
model.conditioner.cpu()

torch.cuda.empty_cache()
torch.cuda.ipc_collect()

# Stage 2: latent sampling. 4x128x128 latents decode to a 1024x1024 image
# (the VAE downsamples by a factor of 8). Call torch.manual_seed(...) first
# if you want reproducible noise.
shape = (1, 4, 128, 128)
randn = torch.randn(shape).to(torch.float16).cuda()


def denoiser(input, sigma, c):
    # Thin wrapper so the sampler can call the denoiser with the UNet bound in.
    return model.denoiser(model.model, input, sigma, c)


# Move the UNet and denoiser to the GPU just for sampling, then offload them again.
model.model.cuda()
model.denoiser.cuda()
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
model.model.cpu()
model.denoiser.cpu()

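# --- Sketch: decoding the sampled latents into an image ---
# Hedged addition, not part of the original script: it assumes sgm's
# DiffusionEngine exposes decode_first_stage() (which rescales the latents and
# runs the VAE decoder) and that its outputs lie in [-1, 1]. PIL is an extra
# dependency here. Note that the stock SDXL VAE is known to be numerically
# unstable in float16; casting the first-stage model to float32 is a common
# workaround if the output comes back as NaNs.
from PIL import Image

model.first_stage_model.cuda()
samples_x = model.decode_first_stage(samples_z)  # latents -> pixels in [-1, 1]
model.first_stage_model.cpu()

samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
image_np = (samples[0].permute(1, 2, 0).float().cpu().numpy() * 255).astype(np.uint8)
Image.fromarray(image_np).save("sample.png")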