File size: 4,887 Bytes
1a904ba
5bf4a08
ab65a35
 
5bf4a08
1a904ba
ebfe47d
b4852a5
1a904ba
 
 
 
3313361
 
5bf4a08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3313361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a904ba
 
 
b4852a5
ebfe47d
1a904ba
69f57c0
 
9993090
5bf4a08
9993090
 
69f57c0
9993090
 
 
 
 
5bf4a08
9993090
 
 
 
 
ebfe47d
9993090
 
 
 
ebfe47d
9993090
 
ebfe47d
9993090
 
ebfe47d
9993090
 
ebfe47d
54ce9a5
9993090
 
ebfe47d
ab65a35
9993090
 
 
 
 
 
ab65a35
9993090
 
 
 
60fe69a
 
9993090
 
 
 
fbbf2aa
9993090
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import math

import einops
import numpy as np
import torch
import gc
import safetensors.torch

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler


def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of the conditioner's embedders.

    Every object in ``conditioner.embedders`` must expose an ``input_key``
    attribute. Duplicates are collapsed through a set, so the order of the
    returned list is not guaranteed.
    """
    unique_keys = {embedder.input_key for embedder in conditioner.embedders}
    return list(unique_keys)


def get_batch(keys, value_dict, N, device="cuda"):
    """Build the conditional and unconditional input batches for the conditioner.

    Args:
        keys: Iterable of embedder input keys to populate.
        value_dict: Source values (prompts, sizes, crop coordinates, scores).
            Keys without special handling are copied through from here as-is.
        N: Sequence of leading batch dimensions, e.g. ``[1]``.
        device: Device the generated tensors are moved to.

    Returns:
        A ``(batch, batch_uc)`` pair of dicts. ``batch_uc`` carries the
        negative prompt / negative aesthetic score where those keys are
        requested, plus a clone of every other tensor found in ``batch``.
    """
    # Hardcoded demo setups; might undergo some changes in the future

    def tiled(values):
        # One row of values, repeated across the leading batch dimensions.
        return torch.tensor(values).to(device).repeat(*N, 1)

    def repeated_text(text):
        # The same string at every position of an N-shaped nested list.
        return np.repeat([text], repeats=math.prod(N)).reshape(N).tolist()

    # Embedder key -> the two value_dict fields that make up its tensor.
    paired_fields = {
        "original_size_as_tuple": ("orig_height", "orig_width"),
        "crop_coords_top_left": ("crop_coords_top", "crop_coords_left"),
        "target_size_as_tuple": ("target_height", "target_width"),
    }

    batch, batch_uc = {}, {}

    for key in keys:
        if key == "txt":
            batch[key] = repeated_text(value_dict["prompt"])
            batch_uc[key] = repeated_text(value_dict["negative_prompt"])
        elif key == "aesthetic_score":
            batch[key] = tiled([value_dict["aesthetic_score"]])
            batch_uc[key] = tiled([value_dict["negative_aesthetic_score"]])
        elif key in paired_fields:
            first_field, second_field = paired_fields[key]
            batch[key] = tiled([value_dict[first_field], value_dict[second_field]])
        else:
            batch[key] = value_dict[key]

    # Tensors without a dedicated unconditional value are shared as copies.
    for key, value in batch.items():
        if isinstance(value, torch.Tensor) and key not in batch_uc:
            batch_uc[key] = torch.clone(value)
    return batch, batch_uc


# Sampler setup: 40 Euler-ancestral steps, with the discretization and
# guidance components named by their import paths.
_discretization_config = {
    "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}

# Dynamic-thresholding component handed to the guider.
_dyn_thresh_config = {
    "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
}

# VanillaCFG guider with a guidance scale of 9.0.
_guider_config = {
    "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
    "params": {"scale": 9.0, "dyn_thresh_config": _dyn_thresh_config},
}

sampler = EulerAncestralSampler(
    num_steps=40,
    discretization_config=_discretization_config,
    guider_config=_guider_config,
    eta=1.0,
    s_noise=1.0,
    verbose=True,
)

# Instantiate the model described by the YAML config, keep it on the CPU in
# eval mode, then load the safetensors checkpoint into it.
config_path = './sd_xl_base.yaml'
config = OmegaConf.load(config_path)
model = instantiate_from_config(config.model)
model = model.cpu()
model.eval()

# strict=False: key mismatches between checkpoint and model do not raise.
_checkpoint = safetensors.torch.load_file('./sd_xl_base_1.0.safetensors')
model.load_state_dict(_checkpoint, strict=False)
del _checkpoint  # do not keep the loaded tensors alive through this name

# model.conditioner.cuda()

# End-to-end text-to-image run: encode the prompts, sample latents, decode
# them with the first-stage model and write the result to 'img.png'.
# One no_grad context covers the whole pipeline; nothing here needs gradients.
with torch.no_grad():

    # The whole model was created on the CPU, so the two text embedders are
    # told to run there as well.
    model.conditioner.embedders[0].device = 'cpu'
    model.conditioner.embedders[1].device = 'cpu'

    value_dict = {
        "prompt": "a handsome man in forest", "negative_prompt": "ugly, bad", "orig_height": 1024, "orig_width": 1024,
        "crop_coords_top": 0, "crop_coords_left": 0, "target_height": 1024, "target_width": 1024, "aesthetic_score": 7.5,
        "negative_aesthetic_score": 2.0,
    }

    # Build the conditioner inputs on the CPU too, so every tensor involved in
    # conditioning lives on the same device as the conditioner itself.
    # NOTE(review): presumably the conditioner combines the outputs of its
    # embedders, which would fail across devices - confirm against sgm.
    batch, batch_uc = get_batch(
        get_unique_embedder_keys_from_conditioner(model.conditioner),
        value_dict,
        [1],
        device='cpu',
    )

    c, uc = model.conditioner.get_unconditional_conditioning(
        batch,
        batch_uc=batch_uc,
    )

    # The diffusion model below runs on the GPU in float16, so the conditioning
    # tensors are moved and cast in one step (a no-op for tensors already there).
    c = {name: tensor.to(device='cuda', dtype=torch.float16) for name, tensor in c.items()}
    uc = {name: tensor.to(device='cuda', dtype=torch.float16) for name, tensor in uc.items()}

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    # Initial latent noise, drawn on the CPU and then moved to the GPU in fp16.
    shape = (1, 4, 128, 128)
    randn = torch.randn(shape).to(torch.float16).cuda()


    def denoiser(input, sigma, c):
        # Adapter handed to the sampler: binds the network to model.denoiser.
        return model.denoiser(model.model, input, sigma, c)


    # Keep the diffusion network on the GPU only while sampling.
    model.model.to(torch.float16).cuda()
    model.denoiser.to(torch.float16).cuda()
    samples_z = sampler(denoiser, randn, cond=c, uc=uc)
    model.model.cpu()
    model.denoiser.cpu()

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

    # Decode the latents in float32 and map the result from [-1, 1] to [0, 1].
    model.first_stage_model.cuda()
    samples_x = model.decode_first_stage(samples_z.float())
    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
    model.first_stage_model.cpu()

    import cv2
    # First image of the batch, channels-last, scaled to the 0..255 range.
    samples = einops.rearrange(samples, 'b c h w -> b h w c')[0] * 255.0
    samples = samples.cpu().numpy().clip(0, 255).astype(np.uint8)
    # OpenCV stores colour images in BGR order, so the (presumably RGB) channel
    # axis is reversed. This is done in NumPy because torch tensors reject
    # negative slice steps; the copy gives cv2 a contiguous array.
    samples = np.ascontiguousarray(samples[:, :, ::-1])
    cv2.imwrite('img.png', samples)