# Source captured from a Hugging Face Space whose status was shown as "Runtime error".
# ๐ Import all necessary libraries | |
import os | |
import argparse | |
from functools import partial | |
from pathlib import Path | |
import sys | |
import random | |
from omegaconf import OmegaConf | |
from PIL import Image | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torchvision import transforms | |
from torchvision.transforms import functional as TF | |
from tqdm import trange | |
from cloob_training import model_pt, pretrained | |
import ldm.models.autoencoder | |
from diffusion import sampling, utils | |
import train_latent_diffusion as train | |
from huggingface_hub import hf_hub_url, cached_download | |
import gradio as gr # ๐จ The magic canvas for AI-powered image generation! | |
# Download the necessary model files from the Hugging Face Hub.
# hf_hub_download stores each file in the local HF cache and returns its local path.
# `cached_download` was deprecated and later removed from huggingface_hub, so it
# is not used here.
# NOTE(review): the top-of-file `cached_download` import will still fail on
# huggingface_hub releases that removed it.
from huggingface_hub import hf_hub_download

checkpoint = hf_hub_download(repo_id="huggan/distill-ccld-wa", filename="model_student.ckpt")
ae_model_path = hf_hub_download(repo_id="huggan/ccld_wa", filename="ae_model.ckpt")
ae_config_path = hf_hub_download(repo_id="huggan/ccld_wa", filename="ae_model.yaml")
# Utility functions: prompt parsing and image resizing/cropping.
def parse_prompt(prompt, default_weight=3.):
    """Split a prompt of the form ``"text[:weight]"`` into ``(text, weight)``.

    The weight is the part after the last ``:`` when it parses as a float.
    Otherwise the whole prompt is returned unchanged with ``default_weight``.
    A colon immediately followed by ``//`` (as in ``http://`` / ``https://``)
    is never treated as the weight separator. URLs may therefore be given
    with or without a trailing ``:weight``.

    Args:
        prompt: Text prompt, image path, or URL, optionally suffixed with
            ``:<weight>``.
        default_weight: Weight used when no valid weight suffix is present.

    Returns:
        Tuple ``(text, weight)`` where ``weight`` is a float.
    """
    head, sep, tail = prompt.rpartition(':')
    if sep and not tail.startswith('//'):
        try:
            return head, float(tail)
        except ValueError:
            # The suffix is not a number (e.g. "Time: 3pm", or a URL port such as
            # "host:8080/x.png"). Treat the colon as part of the prompt itself.
            pass
    return prompt, float(default_weight)
def resize_and_center_crop(image, size):
    """Scale *image* to cover *size*, then crop its centre to exactly *size*.

    Args:
        image: PIL image to transform.
        size: Target ``(width, height)``. It is reversed to ``(height, width)``
            for ``TF.center_crop``.

    Returns:
        The cropped image produced by ``TF.center_crop``.
    """
    target_w, target_h = size[0], size[1]
    src_w, src_h = image.size[0], image.size[1]
    # Use the larger ratio so both sides are at least as big as the target.
    scale = max(target_w / src_w, target_h / src_h)
    resized = image.resize((int(scale * src_w), int(scale * src_h)), Image.LANCZOS)
    return TF.center_crop(resized, size[::-1])
# Model loading: the autoencoder, the diffusion model and CLOOB.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('loading models... ๐ ๏ธ')

# Autoencoder: used by `generate` to decode sampled latents into images.
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
# map_location='cpu' (as for the diffusion checkpoint below) lets a checkpoint
# saved from a GPU load on CPU-only hosts. load_state_dict then copies the
# tensors into the model's parameters on `device`.
ae_model.load_state_dict(torch.load(ae_model_path, map_location='cpu'))
# Latent shape sampled in `generate`: channels, height, width.
n_ch, side_y, side_x = 4, 32, 32

# Diffusion model (distilled student checkpoint downloaded above).
model = train.DiffusionModel(192, [1, 1, 2, 2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)

# CLOOB: encodes text and image prompts into conditioning embeddings.
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
# NOTE: this rebinds the module-level `checkpoint` (previously the diffusion
# checkpoint path) to the CLOOB checkpoint.
checkpoint = pretrained.download_checkpoint(cloob_config)
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
cloob.eval().requires_grad_(False).to(device)
# The key function: sample latents with classifier-free guidance and decode them.
def generate(n=1, prompts=('a red circle',), images=(), seed=42, steps=15, method='plms', eta=None):
    """Generate PIL images from text and image prompts.

    Args:
        n: Number of images to generate.
        prompts: Iterable of text prompts, each optionally suffixed with
            ``:<weight>`` (parsed by ``parse_prompt``).
        images: Iterable of image paths/URLs, each optionally suffixed with
            ``:<weight>``.
        seed: Seed passed to ``torch.manual_seed`` before the initial noise is drawn.
        steps: Number of sampling steps.
        method: Sampler name, one of ``'ddpm'``, ``'ddim'`` or ``'plms'``.
        eta: DDIM noise amount. ``None`` is treated as ``0.``. Only used for ``'ddim'``.

    Returns:
        List of PIL images, one per decoded sample.

    Raises:
        ValueError: If ``method`` is not one of the supported samplers.
    """
    # Fail fast, before any embeddings are computed. An `assert` would be
    # stripped under `python -O`.
    if method not in ('ddpm', 'ddim', 'plms'):
        raise ValueError(f'Unknown sampling method: {method!r}')
    ddim_eta = 0. if eta is None else eta

    # Entry 0 is the unconditional (all-zero) embedding used for classifier-free guidance.
    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []

    # Text prompts -> CLOOB text embeddings.
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)

    # Image prompts -> L2-normalised CLOOB image embeddings.
    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = cloob.config['image_encoder']['image_size']
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    # The unconditional embedding gets weight 1 - sum(prompt weights), so the
    # guidance weights sum to 1.
    weights = torch.tensor([1 - sum(weights), *weights], device=device)
    torch.manual_seed(seed)

    def cfg_model_fn(x, t):
        """Run the model once per conditioning embedding and mix the outputs by weight."""
        batch_n = x.shape[0]
        n_conds = len(target_embeds)
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        embed_in = torch.cat(target_embeds).repeat_interleave(batch_n, 0)
        vs = model(x_in, t_in, embed_in).view([n_conds, batch_n, *x.shape[1:]])
        return vs.mul(weights[:, None, None, None, None]).sum(0)

    def run(x, schedule):
        """Dispatch to the selected sampler. `method` was validated above."""
        if method == 'ddpm':
            return sampling.sample(cfg_model_fn, x, schedule, 1., {})
        if method == 'ddim':
            return sampling.sample(cfg_model_fn, x, schedule, ddim_eta, {})
        return sampling.plms_sample(cfg_model_fn, x, schedule, {})

    # max(..., 1) keeps `trange` valid for n == 0 (a step of 0 would raise).
    batch_size = max(n, 1)
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    # The samplers take a tensor of timesteps, not a step count. Build a
    # decreasing grid over (0, 1] and map it through the spliced DDPM/cosine
    # schedule. Previously `t` was computed but never used, and the bare int
    # `steps` was passed instead.
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    schedule = utils.get_spliced_ddpm_cosine_schedule(t)

    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], schedule)
        # NOTE(review): the 2.55 latent rescale is kept as-is. It differs from the
        # model's autoencoder_scale (4.3084), so confirm it against the training setup.
        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
        pil_ims.extend(utils.to_pil_image(out) for out in outs)
    return pil_ims
# Gradio callback: wraps `generate` to produce one image for the web UI.
def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """Generate a single image for the Gradio interface.

    Args:
        prompt: Text prompt, optionally suffixed with ``:<weight>``.
        im_prompt: Optional image path/URL used as an extra prompt.
            ``None`` means no image prompt.
        seed: RNG seed. When ``None``, a random integer in [0, 10000] is drawn.
        n_steps: Number of sampling steps forwarded to ``generate``.
        method: Sampler name forwarded to ``generate``.

    Returns:
        The first image returned by ``generate``.
    """
    seed = random.randint(0, 10000) if seed is None else seed
    image_prompts = [] if im_prompt is None else [im_prompt]
    results = generate(
        n=1,
        prompts=[prompt],
        images=image_prompts,
        seed=seed,
        steps=n_steps,
        method=method,
    )
    return results[0]
# Gradio UI: a text prompt plus an optional image prompt, producing one image.
# Newer Gradio releases removed the `optional=` component argument, so it is not
# passed here. Inputs may be left empty, and an empty image arrives as None.
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.Textbox(label="Text prompt"),
        gr.Image(label="Image prompt", type='filepath'),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    examples=[
        ["Virgin and Child, in the style of Jacopo Bellini"],
        ["Art Nouveau, in the style of John Singer Sargent"],
        ["Neoclassicism, in the style of Gustav Klimt"],
        ["Abstract Art, in the style of M.C. Escher"],
        ['Surrealism, in the style of Salvador Dali'],
        ["Romanesque Art, in the style of Leonardo da Vinci"],
        ["landscape"],
        ["portrait"],
        ["sculpture"],
        ["photo"],
        ["figurative"],
        ["illustration"],
        ["still life"],
        ["cityscape"],
        ["marina"],
        ["animal painting"],
        ["graffiti"],
        ["mythological painting"],
        ["battle painting"],
        ["self-portrait"],
        ["Impressionism, oil on canvas"],
        ["Katsushika Hokusai, The Dragon of Smoke Escaping from Mount Fuji"],
        ["Moon Light Sonata by Basuki Abdullah"],
        ["Two Trees by M.C. Escher"],
        ["Futurism, in the style of Wassily Kandinsky"],
        ["Surrealism, in the style of Edgar Degas"],
        ["Expressionism, in the style of Wassily Kandinsky"],
        ["Futurism, in the style of Egon Schiele"],
        ["Cubism, in the style of Gustav Klimt"],
        ["Op Art, in the style of Marc Chagall"],
        ["Romanticism, in the style of M.C. Escher"],
        ["Futurism, in the style of M.C. Escher"],
        ["Mannerism, in the style of Paul Klee"],
        ["High Renaissance, in the style of Rembrandt"],
        ["Magic Realism, in the style of Gustave Dore"],
        ["Realism, in the style of Jean-Michel Basquiat"],
        ["Art Nouveau, in the style of Paul Gauguin"],
        ["Avant-garde, in the style of Pierre-Auguste Renoir"],
        ["Baroque, in the style of Edward Hopper"],
        ["Post-Impressionism, in the style of Wassily Kandinsky"],
        ["Naturalism, in the style of Rene Magritte"],
        ["Constructivism, in the style of Paul Cezanne"],
        ["Abstract Expressionism, in the style of Henri Matisse"],
        ["Pop Art, in the style of Vincent van Gogh"],
        ["Futurism, in the style of Zdzislaw Beksinski"],
        ["Aaron Wacker, oil on canvas"],
    ],
    title='Art Generator and Style Mixer from ๐ง Cloob and ๐จ WikiArt - Visual Art Encyclopedia',
    description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
    article='Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa).'
)

# Launch the Gradio interface. Queueing is enabled with `.queue()`; newer Gradio
# removed the `enable_queue=` argument to `launch()`.
iface.queue().launch()