import open_clip
import gradio as gr
import numpy as np
import torch
import torchvision
from tqdm.auto import tqdm
from PIL import Image, ImageColor
from torchvision import transforms
from diffusers import DDIMScheduler, DDPMPipeline
device = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
# Load the pretrained pipeline
pipeline_name = "alkzar90/sd-class-ukiyo-e-256"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
# Sample some images with a DDIM Scheduler over 40 steps
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
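# (A DDIM sampler can produce reasonable samples in far fewer steps than the
# 1,000-step DDPM training schedule, which keeps the demo responsive.)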
# Color guidance
#-------------------------------------------------------------------------------
# Color guidance function
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Given a target color (R, G, B), return a loss for how far away, on average,
    the images' pixels are from that color. Defaults to a light teal: (0.1, 0.9, 0.5)."""
    target = (
        torch.tensor(target_color).to(images.device) * 2 - 1
    )  # Map the target color to (-1, 1)
    target = target[
        None, :, None, None
    ]  # Get the shape right to broadcast over images of shape (b, c, h, w)
    error = torch.abs(
        images - target
    ).mean()  # Mean absolute difference between the image pixels and the target color
    return error
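# Quick sanity check (illustrative only, not executed by the app): a batch that already
# matches the target color should score ~0, while other colors score higher. For example:
#   dummy = torch.full((1, 3, 8, 8), -1.0)                # all-black batch in (-1, 1) space
#   print(color_loss(dummy).item())                        # clearly non-zero: far from the teal target
#   teal = torch.tensor((0.1, 0.9, 0.5)) * 2 - 1
#   print(color_loss(teal[None, :, None, None].expand(1, 3, 8, 8)).item())  # ~0.0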
# CLIP guidance
#-------------------------------------------------------------------------------
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)
# Transforms to resize and augment an image + normalize to match CLIP's training data
tfms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),  # Random crop each time
        transforms.RandomAffine(
            5
        ),  # One possible random augmentation: a small random rotation (up to 5 degrees)
        transforms.RandomHorizontalFlip(),  # You can add additional augmentations if you like
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
# CLIP guidance function
def clip_loss(image, text_features):
    image_features = clip_model.encode_image(
        tfms(image)
    )  # Note: applies the above transforms
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = (
        input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    )  # Squared great-circle distance between the normalized embeddings
    return dists.mean()
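# Illustrative usage (a sketch only; assumes a prompt string and an image batch in the
# (-1, 1) range, mirroring the tokenize/encode_text calls used inside generate() below):
#   text = open_clip.tokenize(["a woodblock print of a wave"]).to(device)
#   with torch.no_grad():
#       text_features = clip_model.encode_text(text)
#   fake_images = torch.randn(2, 3, 256, 256).to(device)
#   print(clip_loss(fake_images, text_features).item())    # scalar guidance loss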
# Sample generator loop
#-------------------------------------------------------------------------------
def generate(
    color,
    color_loss_scale,
    num_examples=4,
    seed=None,
    prompt=None,
    prompt_loss_scale=None,
    prompt_n_cuts=None,
    inference_steps=50,
):
    scheduler.set_timesteps(num_inference_steps=inference_steps)
    if seed:
        torch.manual_seed(seed)
    if prompt:
        text = open_clip.tokenize([prompt]).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = clip_model.encode_text(text)
    target_color = ImageColor.getcolor(color, "RGB")  # Target color as RGB
    target_color = [a / 255 for a in target_color]  # Rescale from (0, 255) to (0, 1)
    x = torch.randn(num_examples, 3, 256, 256).to(device)
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        # Color loss
        loss = color_loss(x0, target_color) * color_loss_scale
        cond_color_grad = -torch.autograd.grad(loss, x)[0]
        # Modify x based solely on the color gradient -> x_cond
        x_cond = x.detach() + cond_color_grad
        # Prompt loss: modify x_cond with cond_prompt_grad, computed from
        # the original x (not the version already shifted by cond_color_grad)
        if prompt:
            cond_prompt_grad = 0
            for cut in range(prompt_n_cuts):
                # Set requires_grad on x
                x = x.detach().requires_grad_()
                # Get the predicted x0
                x0 = scheduler.step(noise_pred, t, x).pred_original_sample
                # Calculate the CLIP loss
                prompt_loss = clip_loss(x0, text_features) * prompt_loss_scale
                # Accumulate the gradient (divide by prompt_n_cuts since we want the average over cuts)
                cond_prompt_grad -= (
                    torch.autograd.grad(prompt_loss, x, retain_graph=True)[0]
                    / prompt_n_cuts
                )
            # Modify x_cond based on this gradient, scaled by sqrt(alpha_bar)
            # so the guidance strength follows the noise schedule
            alpha_bar = scheduler.alphas_cumprod[i]
            x_cond = x_cond + cond_prompt_grad * alpha_bar.sqrt()
        x = scheduler.step(noise_pred, t, x_cond).prev_sample
    grid = torchvision.utils.make_grid(x, nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))
    # im.save("test.jpeg")
    return im
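# Illustrative call (a sketch only; argument values are hypothetical and the Gradio
# interface below is what actually drives this function in the Space):
#   im = generate(
#       color="#44CCAA",
#       color_loss_scale=8.9,
#       num_examples=4,
#       seed=1512,
#       prompt="A sakura tree",
#       prompt_loss_scale=60,
#       prompt_n_cuts=4,
#       inference_steps=40,
#   )
#   im.save("sample_grid.jpeg")  # hypothetical output path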
# GRADIO Interface
#-------------------------------------------------------------------------------
TITLE = "Ukiyo-e postcard generator service 🎴!"
DESCRIPTION = "This is a diffusion model for unconditional image generation of Ukiyo-e images ✍ 🎨. \nThe model was trained by fine-tuning the google/ddpm-celebahq-256 pretrained model on the dataset: https://huggingface.co/datasets/huggan/ukiyoe2photo"
CSS = ".output-image, .input-image, .image-preview {height: 250px !important}"
# See the gradio docs for the types of inputs and outputs available
inputs = [
    gr.ColorPicker(label="color (click on the square to pick the color)", value="#DF5C16"),  # Add any inputs you need here
    gr.Slider(label="color_guidance_scale (how strongly to blend the color)", minimum=0, maximum=30, value=6.7),
    gr.Slider(label="num_examples (# images generated)", minimum=2, maximum=12, value=2, step=4),
    gr.Number(label="seed (for reproducibility and experimentation)", value=666),
    gr.Text(label="Text prompt (optional)", value=None),
    gr.Slider(label="prompt_guidance_scale (how strongly to steer toward the text prompt)", minimum=0, maximum=1000, value=10),
    gr.Slider(label="prompt_n_cuts", minimum=4, maximum=12, step=4),
    gr.Slider(label="Number of inference steps (more steps -> stronger guidance effect)", minimum=40, maximum=60, value=40, step=1),
]
outputs = gr.Image(label="result")
# And the minimal interface
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    css=CSS,
    # examples=[
    #     ["#DF5C16", 6.7, 12, 666, None, None, None, 40],
    #     ["#C01660", 13.5, 12, 1990, None, None, None, 40],
    #     ["#44CCAA", 8.9, 12, 1512, None, None, None, 40],
    #     ["#39A291", 5.0, 2, 666, "A sakura tree", 60, 4, 40],
    #     ["#0E0907", 0.0, 12, 666, "A big whale in the ocean", 60, 8, 52],
    #     ["#19A617", 4.6, 12, 666, "An island with sunset at background", 140, 4, 47],
    # ],
    title=TITLE,
    description=DESCRIPTION,
)
if __name__ == "__main__":
    demo.launch(enable_queue=True)