import open_clip
import gradio as gr
import numpy as np
import torch
import torchvision

from tqdm.auto import tqdm
from PIL import Image, ImageColor
from torchvision import transforms
from diffusers import DDIMScheduler, DDPMPipeline


device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

# Load the pretrained pipeline
pipeline_name = "alkzar90/sd-class-ukiyo-e-256"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# Sample some images with a DDIM Scheduler over 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
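# (DDIM can sample in far fewer steps than the full training schedule; the Gradio
# app below overrides this with the user-selected number of inference steps)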


# Color guidance
#-------------------------------------------------------------------------------
# Color guidance function
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Given a target color (R, G, B) return a loss for how far away on average
    the images' pixels are from that color. Defaults to a light teal: (0.1, 0.9, 0.5)"""
    target = (
        torch.tensor(target_color).to(images.device) * 2 - 1
    )  # Map target color to (-1, 1)
    target = target[
        None, :, None, None
    ]  # Get shape right to work with the images (b, c, h, w)
    error = torch.abs(
        images - target
    ).mean()  # Mean absolute difference between the image pixels and the target color
    return error
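
# Example (illustrative): a loss that pulls generations toward red would be
# color_loss(images, target_color=(0.9, 0.1, 0.1))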


    
# CLIP guidance
#-------------------------------------------------------------------------------
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)

# Transforms to resize and augment an image + normalize to match CLIP's training data
tfms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),  # Random CROP each time
        transforms.RandomAffine(
            5
        ),  # One possible random augmentation: skews the image
        transforms.RandomHorizontalFlip(),  # You can add additional augmentations if you like
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


# CLIP guidance function
def clip_loss(image, text_features):
    image_features = clip_model.encode_image(
        tfms(image)
    )  # Note: applies the above transforms
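    # tfms includes random crop/affine/flip, so each call sees a slightly different view of the image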
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = (
        input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    )  # Squared Great Circle Distance
    return dists.mean()    



#  Sample generator loop
#-------------------------------------------------------------------------------
def generate(color, 
             color_loss_scale, 
             num_examples=4, 
             seed=None, 
             prompt=None,
             prompt_loss_scale=None,
             prompt_n_cuts=None,
             inference_steps=50,
             ):
    # Cast numeric inputs, since Gradio may deliver them as floats
    scheduler.set_timesteps(num_inference_steps=int(inference_steps))

    if seed:
        torch.manual_seed(int(seed))

    if prompt:
        text = open_clip.tokenize([prompt]).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = clip_model.encode_text(text)

    target_color = ImageColor.getcolor(color, "RGB")  # Target color as RGB
    target_color = [a / 255 for a in target_color]  # Rescale from (0, 255) to (0, 1)

    x = torch.randn(int(num_examples), 3, 256, 256).to(device)  # Start from random noise

    for i, t in tqdm(enumerate(scheduler.timesteps)):
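        # Each step: predict the noise, estimate the denoised image x0, apply the
        # guidance loss(es) on x0, and nudge x along the negative gradient before
        # taking the scheduler step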
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample

        # color loss
        loss = color_loss(x0, target_color) * color_loss_scale
        cond_color_grad = -torch.autograd.grad(loss, x)[0]
        # Modify x based solely on the color gradient -> x_cond
        x_cond = x.detach() + cond_color_grad
        
        # Prompt (CLIP) loss: compute cond_prompt_grad from the original x
        # (not the x already modified by cond_color_grad) and add it to x_cond
        if prompt:
            cond_prompt_grad = 0
            for cut in range(int(prompt_n_cuts)):
                # Set requires_grad on x
                x = x.detach().requires_grad_()
                # Get the predicted x0:
                x0 = scheduler.step(noise_pred, t, x).pred_original_sample
                # Calculate loss
                prompt_loss = clip_loss(x0, text_features) * prompt_loss_scale
                # Get gradient (scale by n_cuts since we want the average)
                cond_prompt_grad -= (
                    torch.autograd.grad(prompt_loss, x, retain_graph=True)[0]
                    / prompt_n_cuts
                )
            # Modify x_cond based on this gradient
            alpha_bar = scheduler.alphas_cumprod[i]
            x_cond = (
                x_cond + cond_prompt_grad * alpha_bar.sqrt()
            )  # Note the additional scaling factor here!


        x = scheduler.step(noise_pred, t, x_cond).prev_sample
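    # Arrange the batch into a grid and map pixel values from (-1, 1) back to (0, 1)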
    grid = torchvision.utils.make_grid(x, nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))
    im.save("test.jpeg")
    return im
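

# Quick local sanity check (illustrative values; not used by the Gradio app):
# generate("#39A291", color_loss_scale=5.0, num_examples=4, seed=666,
#          prompt="A sakura tree", prompt_loss_scale=60, prompt_n_cuts=4,
#          inference_steps=40)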

                 

# GRADIO Interface
#-------------------------------------------------------------------------------                
TITLE="Ukiyo-e postal generator service 🎴!"
DESCRIPTION="This model is a diffusion model for unconditional image generation of Ukiyo-e images ✍ 🎨. \nThe model was train using fine-tuning with the google/ddpm-celebahq-256 pretrain-model and the dataset: https://huggingface.co/datasets/huggan/ukiyoe2photo"
CSS = ".output-image, .input-image, .image-preview {height: 250px !important}"

# See the gradio docs for the types of inputs and outputs available
inputs = [
    gr.ColorPicker(label="color (click on the square to pick the color)", value="#DF5C16"),  # Add any inputs you need here
    gr.Slider(label="color_guidance_scale (how strong to blend the color)", minimum=0, maximum=30, value=6.7),
    gr.Slider(label="num_examples (# images generated)", minimum=4, maximum=12, value=8, step=4),
    gr.Number(label="seed (reproducibility and experimentation)", value=666),
    gr.Text(label="Text prompt (optional)", value=None),
    gr.Slider(label="prompt_guidance_scale (...)", minimum=0, maximum=1000, value=10),
    gr.Slider(label="prompt_n_cuts", minimum=4, maximum=12, step=4),
    gr.Slider(label="Number of inference steps (+ steps -> + guidance effect)", minimum=40, maximum=60, value=40, step=1),
]

outputs = gr.Image(label="result")

# And the minimal interface
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    css=CSS,
    examples=[
        #["#DF5C16", 6.7, 12, 666, None, None, None, 40],
        #["#C01660", 13.5, 12, 1990, None, None, None, 40],
        #["#44CCAA", 8.9, 12, 1512, None, None, None, 40], 
        ["#39A291", 5.0, 8, 666, "A sakura tree", 60, 4, 52],
        #["#0E0907", 0.0, 12, 666, "A big whale in the ocean", 60, 8, 52],
        #["#19A617", 4.6, 12, 666, "An island with sunset at background", 140, 4, 47],
    ],
    title=TITLE,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    demo.launch(enable_queue=True)