# 🚀 Import all necessary libraries
import random

from omegaconf import OmegaConf
from PIL import Image
import torch
from torch.nn import functional as F
from torchvision.transforms import functional as TF
from tqdm import trange

from cloob_training import model_pt, pretrained
import ldm.models.autoencoder
from diffusion import sampling, utils
import train_latent_diffusion as train
from huggingface_hub import hf_hub_url, cached_download
import gradio as gr  # 🎨 The magic canvas for AI-powered image generation!

# 🖼️ Download the necessary model files
# These checkpoints and configs are fetched once and cached from the Hugging Face Hub
checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))

# 📐 Utility Functions: Math and images, what could go wrong?
# These functions help parse prompts and resize/crop images to fit nicely

def parse_prompt(prompt, default_weight=3.):
    """
    🎯 Splits a 'text:weight' prompt into its text and weight parts.
    URLs get special handling so their colons are not mistaken for the
    weight separator; prompts without a weight get default_weight.
    """
    if prompt.startswith('http://') or prompt.startswith('https://'):
        vals = prompt.rsplit(':', 2)
        vals = [vals[0] + ':' + vals[1], *vals[2:]]
    else:
        vals = prompt.rsplit(':', 1)
    vals = vals + ['', default_weight][len(vals):]
    return vals[0], float(vals[1])
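
# Quick self-checks of the parsing rules above (the prompts and URL are
# made-up examples; these run cheaply at import time):
assert parse_prompt('a red circle') == ('a red circle', 3.0)
assert parse_prompt('a red circle:2.5') == ('a red circle', 2.5)
assert parse_prompt('https://example.com/cat.jpg:1.5') == ('https://example.com/cat.jpg', 1.5)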

def resize_and_center_crop(image, size):
    """
    ✂️ Scale the image so it covers `size` (width, height), then center-crop.
    """
    fac = max(size[0] / image.size[0], size[1] / image.size[1])
    image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
    return TF.center_crop(image, size[::-1])  # center_crop expects (height, width)
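
# Worked example: a 640x480 source fitted to size=(224, 224) uses
# fac = max(224/640, 224/480) ≈ 0.467, is resized to roughly 298x224,
# then center-cropped to exactly 224x224.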


# 🧠 Model loading: the brain of our operation! 🔥
# Load all the models: autoencoder, diffusion, and CLOOB

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print('Loading models... 🛠️')

# 🔧 Autoencoder Setup: Let’s decode the madness into images
ae_config = OmegaConf.load(ae_config_path)
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
ae_model.eval().requires_grad_(False).to(device)
ae_model.load_state_dict(torch.load(ae_model_path, map_location='cpu'))  # map_location keeps CPU-only machines happy
n_ch, side_y, side_x = 4, 32, 32
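# (Assumption: with the usual f=8 KL autoencoder config from ldm, a 4x32x32
# latent decodes to a 256x256 RGB image.)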

# 🌀 Diffusion Model Setup: The artist behind the scenes
model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
model = model.to(device).eval().requires_grad_(False)

# 👁️ CLOOB Setup: Our vision-language model that connects text and images
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(cloob_config)
cloob_checkpoint = pretrained.download_checkpoint(cloob_config)  # don't clobber `checkpoint` above
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, cloob_checkpoint))
cloob.eval().requires_grad_(False).to(device)
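
# CLOOB provides both towers used below: cloob.tokenize + cloob.text_encoder
# embed text, cloob.normalize + cloob.image_encoder embed images, and both map
# into the same cloob.config['d_embed']-dimensional space (see generate()).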


# 🎨 The key function: Where the magic happens!
# This is where we generate images based on text and image prompts

def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='plms', eta=None):
    """
    🖼️ Generates a list of PIL images from text and/or image prompts.
    `method` selects the sampler ('ddpm', 'ddim', or 'plms'); `eta` is only
    used by the 'ddim' sampler.
    """
    # The zero embedding stands in for the unconditional (no-prompt) case
    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
    target_embeds, weights = [zero_embed], []

    # Parse text prompts
    for prompt in prompts:
        txt, weight = parse_prompt(prompt)
        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
        weights.append(weight)

    # Parse image prompts
    for prompt in images:
        path, weight = parse_prompt(prompt)
        img = Image.open(utils.fetch(path)).convert('RGB')
        clip_size = cloob.config['image_encoder']['image_size']
        img = resize_and_center_crop(img, (clip_size, clip_size))
        batch = TF.to_tensor(img)[None].to(device)
        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
        target_embeds.append(embed)
        weights.append(weight)

    # Stack the weights; the unconditional embedding gets 1 - sum(weights) so
    # all guidance weights sum to 1. Seed the RNG for reproducible samples.
    weights = torch.tensor([1 - sum(weights), *weights], device=device)
    torch.manual_seed(seed)

    # 💡 Model function with classifier-free guidance
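    # Classifier-free guidance: v = w_0 * v_uncond + sum_i w_i * v_cond_i,
    # where w_0 = 1 - sum_i w_i is the weight of the zero (unconditional) embed.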
    def cfg_model_fn(x, t):
        n = x.shape[0]
        n_conds = len(target_embeds)
        # Tile the batch so the model sees every (sample, condition) pair at once
        x_in = x.repeat([n_conds, 1, 1, 1])
        t_in = t.repeat([n_conds])
        embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
        vs = model(x_in, t_in, embed_in).view([n_conds, n, *x.shape[1:]])
        # Weighted sum over conditions collapses back to one prediction per sample
        v = vs.mul(weights[:, None, None, None, None]).sum(0)
        return v

    # 🎞️ Run the chosen sampler to turn noise into latents
    def run(x, steps):
        if method == 'ddpm':
            return sampling.sample(cfg_model_fn, x, steps, 1., {})
        if method == 'ddim':
            return sampling.sample(cfg_model_fn, x, steps, eta, {})
        if method == 'plms':
            return sampling.plms_sample(cfg_model_fn, x, steps, {})
        raise ValueError(f'unknown sampling method: {method}')

    # 🏃‍♂️ Generate the output images (batch_size equals n, so the loop runs once)
    batch_size = n
    x = torch.randn([n, n_ch, side_y, side_x], device=device)
    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    pil_ims = []
    for i in trange(0, n, batch_size):
        cur_batch_size = min(n - i, batch_size)
        out_latents = run(x[i:i + cur_batch_size], steps)
        # Rescale the latents before decoding them back to pixel space
        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
        for out in outs:
            pil_ims.append(utils.to_pil_image(out))

    return pil_ims
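
# A minimal usage sketch (illustrative prompt and values; not run at import):
#   ims = generate(n=1, prompts=['a sunset over the sea:3'], seed=0, steps=20)
#   ims[0].save('out.png')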


# 🖌️ Interface: Gradio's brush to paint the UI
# Gradio is used here to create a user-friendly interface for art generation.

def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
    """
    💡 Gradio callback: wraps generate() to produce a single image.
    """
    if seed is None:
        seed = random.randint(0, 10000)  # fresh random seed when none is given
    prompts = [prompt]
    im_prompts = []
    if im_prompt is not None:
        im_prompts = [im_prompt]
    pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
    return pil_ims[0]
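
# For example, gen_ims('Impressionism, oil on canvas') returns one PIL image;
# leaving seed=None draws a fresh random seed on each call.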

# 🖼️ Gradio UI: The interface where users can input text or image prompts
iface = gr.Interface(
    fn=gen_ims,
    inputs=[
        gr.Textbox(label="Text prompt"),
        gr.Image(optional=True, label="Image prompt", type='filepath')
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    examples=[
        ["Virgin and Child, in the style of Jacopo Bellini"],
        ["Art Nouveau, in the style of John Singer Sargent"],
        ["Neoclassicism, in the style of Gustav Klimt"],
        ["Abstract Art, in the style of M.C. Escher"],
        ["Surrealism, in the style of Salvador Dali"],
        ["Romanesque Art, in the style of Leonardo da Vinci"],
        ["landscape"],
        ["portrait"],
        ["sculpture"],
        ["photo"],
        ["figurative"],
        ["illustration"],
        ["still life"],
        ["cityscape"],
        ["marina"],
        ["animal painting"],
        ["graffiti"],
        ["mythological painting"],
        ["battle painting"],
        ["self-portrait"],
        ["Impressionism, oil on canvas"],
        ["Katsushika Hokusai, The Dragon of Smoke Escaping from Mount Fuji"],
        ["Moon Light Sonata by Basuki Abdullah"],
        ["Two Trees by M.C. Escher"],
        ["Futurism, in the style of Wassily Kandinsky"],
        ["Surrealism, in the style of Edgar Degas"],
        ["Expressionism, in the style of Wassily Kandinsky"],
        ["Futurism, in the style of Egon Schiele"],
        ["Cubism, in the style of Gustav Klimt"],
        ["Op Art, in the style of Marc Chagall"],
        ["Romanticism, in the style of M.C. Escher"],
        ["Futurism, in the style of M.C. Escher"],
        ["Mannerism, in the style of Paul Klee"],
        ["High Renaissance, in the style of Rembrandt"],
        ["Magic Realism, in the style of Gustave Dore"],
        ["Realism, in the style of Jean-Michel Basquiat"],
        ["Art Nouveau, in the style of Paul Gauguin"],
        ["Avant-garde, in the style of Pierre-Auguste Renoir"],
        ["Baroque, in the style of Edward Hopper"],
        ["Post-Impressionism, in the style of Wassily Kandinsky"],
        ["Naturalism, in the style of Rene Magritte"],
        ["Constructivism, in the style of Paul Cezanne"],
        ["Abstract Expressionism, in the style of Henri Matisse"],
        ["Pop Art, in the style of Vincent van Gogh"],
        ["Futurism, in the style of Zdzislaw Beksinski"],
        ["Aaron Wacker, oil on canvas"]
    ],
    title='Art Generator and Style Mixer from 🧠 CLOOB and 🎨 WikiArt - Visual Art Encyclopedia',
    description='Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, a visual-art encyclopedia.',
    article='Model used: [model card](https://huggingface.co/huggan/distill-ccld-wa).'
)

# 🚀 Launch the Gradio interface
iface.launch(enable_queue=True)