import open_clip
import gradio as gr
import numpy as np
import torch
import torchvision
from tqdm.auto import tqdm
from PIL import Image, ImageColor
from torchvision import transforms
from diffusers import DDIMScheduler, DDPMPipeline


def _select_device():
    """Return the device name to run on: "mps" if available, else "cuda", else "cpu"."""
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = _select_device()

# Load the pretrained pipeline and move it to the selected device.
pipeline_name = "alkzar90/sd-class-ukiyo-e-256"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# DDIM scheduler built from the same repository. The 40 steps set here are only
# the initial configuration; generate() calls set_timesteps again per request.
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)


# Color guidance
#-------------------------------------------------------------------------------
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Return how far, on average, the pixels of `images` are from a target color.

    Args:
        images: batch of images shaped (b, c, h, w) with values in (-1, 1).
        target_color: (R, G, B) color in the (0, 1) range. Defaults to
            (0.1, 0.9, 0.5).

    Returns:
        Scalar tensor holding the mean absolute difference between every
        pixel channel and the corresponding target channel.
    """
    # Shift the color from (0, 1) into the images' (-1, 1) range.
    rgb = torch.tensor(target_color).to(images.device) * 2 - 1
    # (c,) -> (1, c, 1, 1) so it broadcasts across batch, height and width.
    rgb = rgb.view(1, -1, 1, 1)
    return (images - rgb).abs().mean()


# CLIP guidance
#-------------------------------------------------------------------------------
# ViT-B/32 with the OpenAI weights. The returned preprocessing transform is not
# used; images are prepared with `tfms` below instead.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)

# Normalization statistics matching CLIP's training data.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Random augmentations applied before every CLIP encoding (fresh random
# parameters on each call), followed by CLIP normalization.
# NOTE(review): Normalize with these statistics assumes pixels in (0, 1);
# confirm the range of the images callers pass in.
tfms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),  # random crop, resized to 224x224
        transforms.RandomAffine(5),  # random rotation within +/-5 degrees
        transforms.RandomHorizontalFlip(),  # more augmentations can be added here
        transforms.Normalize(mean=_CLIP_MEAN, std=_CLIP_STD),
    ]
)
# CLIP guidance function
def clip_loss(image, text_features):
    """Return the mean squared great-circle distance between CLIP embeddings.

    `image` is passed through the random `tfms` augmentations before encoding,
    so repeated calls on the same input give different values.

    Args:
        image: image batch to encode (assumed (b, c, h, w) - TODO confirm).
        text_features: CLIP text embeddings, one row per prompt.

    Returns:
        Scalar tensor: the distance averaged over every (image, prompt) pair.
    """
    # NOTE(review): tfms normalizes with CLIP statistics meant for pixels in
    # (0, 1), while the sampler below passes x0 in (-1, 1). Rescaling here would
    # change how the existing guidance scales behave, so it is left as-is -
    # confirm intent.
    image_features = clip_model.encode_image(tfms(image))
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    # Squared great-circle distance on unit vectors: 2 * arcsin(|a - b| / 2) ** 2.
    dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    return dists.mean()


# Sample generator loop
#-------------------------------------------------------------------------------
def _encode_prompt(prompt):
    """Tokenize `prompt` and return its CLIP text embedding, without gradients."""
    tokens = open_clip.tokenize([prompt]).to(device)
    # Mixed precision only takes effect when CUDA is available.
    with torch.no_grad(), torch.autocast(device_type="cuda"):
        return clip_model.encode_text(tokens)


def _to_grid_image(samples, nrow=4):
    """Tile a batch of images with values in (-1, 1) into a single PIL image."""
    grid = torchvision.utils.make_grid(samples.detach(), nrow=nrow)
    pixels = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    return Image.fromarray(np.array(pixels * 255).astype(np.uint8))


def generate(color, color_loss_scale, num_examples=4, seed=None,
             prompt=None, prompt_loss_scale=None, prompt_n_cuts=None,
             inference_steps=50,
             ):
    """Sample images with color guidance and optional CLIP text guidance.

    Args:
        color: target color string understood by PIL's ImageColor (e.g. "#DF5C16").
        color_loss_scale: weight applied to the color loss before taking its gradient.
        num_examples: number of images to sample; the grid shows 4 per row.
        seed: seed for torch's global RNG; None leaves the RNG untouched.
        prompt: optional text prompt; a falsy value (None or "") disables CLIP guidance.
        prompt_loss_scale: weight of the CLIP loss; needed when `prompt` is set.
        prompt_n_cuts: number of randomly augmented CLIP evaluations averaged
            per step; needed when `prompt` is set.
        inference_steps: number of DDIM sampling steps.

    Returns:
        A PIL image containing all samples tiled in a grid.
    """
    # UI widgets may deliver numbers as floats (e.g. 8.0); tensor shapes,
    # range() and the scheduler need ints.
    num_examples = int(num_examples)
    inference_steps = int(inference_steps)
    scheduler.set_timesteps(num_inference_steps=inference_steps)

    # `is not None` so that a seed of 0 is honored too.
    if seed is not None:
        torch.manual_seed(int(seed))

    text_features = None
    if prompt:
        text_features = _encode_prompt(prompt)
        prompt_n_cuts = int(prompt_n_cuts)

    # Target color as RGB, rescaled from (0, 255) to (0, 1).
    target_color = [a / 255 for a in ImageColor.getcolor(color, "RGB")]

    # Noise is drawn with torch's default (CPU) generator, then moved to `device`.
    x = torch.randn(num_examples, 3, 256, 256).to(device)

    for i, t in enumerate(tqdm(scheduler.timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        # Predict x0 from a grad-tracking copy of x, so that losses on x0 can be
        # differentiated with respect to the current noisy sample.
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample

        # Color loss: move x against its gradient -> x_cond. Keep the graph
        # alive when the prompt loss below still needs to reuse x0.
        loss = color_loss(x0, target_color) * color_loss_scale
        cond_color_grad = -torch.autograd.grad(loss, x, retain_graph=bool(prompt))[0]
        x_cond = x.detach() + cond_color_grad

        # Prompt loss: the gradient is taken with respect to the original x (not
        # the color-modified x_cond) and averaged over the random CLIP cuts.
        if prompt:
            cond_prompt_grad = 0
            for _ in range(prompt_n_cuts):
                prompt_loss = clip_loss(x0, text_features) * prompt_loss_scale
                grad = torch.autograd.grad(prompt_loss, x, retain_graph=True)[0]
                cond_prompt_grad -= grad / prompt_n_cuts
            # NOTE(review): indexed by the loop step `i`, not the timestep `t`,
            # so this always reads the first `inference_steps` entries of
            # alphas_cumprod whatever the noise level. Changing it would rescale
            # the prompt guidance - confirm intent.
            alpha_bar = scheduler.alphas_cumprod[i]
            x_cond = x_cond + cond_prompt_grad * alpha_bar.sqrt()

        x = scheduler.step(noise_pred, t, x_cond).prev_sample

    im = _to_grid_image(x, nrow=4)
    # NOTE(review): every call overwrites test.jpeg in the working directory.
    im.save("test.jpeg")
    return im


# GRADIO Interface
#-------------------------------------------------------------------------------
TITLE="Ukiyo-e postal generator service 🎴!"
DESCRIPTION="This model is a diffusion model for unconditional image generation of Ukiyo-e images ✍ 🎨. \nThe model was train using fine-tuning with the google/ddpm-celebahq-256 pretrain-model and the dataset: https://huggingface.co/datasets/huggan/ukiyoe2photo"
CSS = ".output-image, .input-image, .image-preview {height: 250px !important}"

# Input widgets, in the same order as generate()'s positional parameters.
inputs = [
    gr.ColorPicker(label="color (click on the square to pick the color)", value="#DF5C16"),
    gr.Slider(label="color_guidance_scale (how strong to blend the color)", minimum=0, maximum=30, value=6.7),
    gr.Slider(label="num_examples (# images generated)", minimum=4, maximum=12, value=8, step=4),
    gr.Number(label="seed (reproducibility and experimentation)", value=666),
    gr.Text(label="Text prompt (optional)", value=None),
    gr.Slider(label="prompt_guidance_scale (...)", minimum=0, maximum=1000, value=10),
    gr.Slider(label="prompt_n_cuts", minimum=4, maximum=12, step=4),
    gr.Slider(label="Number of inference steps (+ steps -> + guidance effect)", minimum=40, maximum=60, value=40, step=1),
]
outputs = gr.Image(label="result")

demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    css=CSS,
    examples=[
        #["#DF5C16", 6.7, 12, 666, None, None, None, 40],
        #["#C01660", 13.5, 12, 1990, None, None, None, 40],
        #["#44CCAA", 8.9, 12, 1512, None, None, None, 40],
        ["#39A291", 5.0, 8, 666, "A sakura tree", 60, 4, 52],
        #["#0E0907", 0.0, 12, 666, "A big whale in the ocean", 60, 8, 52],
        #["#19A617", 4.6, 12, 666, "An island with sunset at background", 140, 4, 47],
    ],
    title=TITLE,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    # launch(enable_queue=True) is deprecated in Gradio 3.x and rejected by 4.x;
    # .queue() is the supported way to turn on request queuing.
    demo.queue().launch()