Update app.py
app.py CHANGED
@@ -1,17 +1,10 @@
+# 🚀 Import all necessary libraries
 import os
-
-os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
-os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")
-
 import argparse
 from functools import partial
 from pathlib import Path
 import sys
-
-sys.path.append('./cloob-latent-diffusion/cloob-training')
-sys.path.append('./cloob-latent-diffusion/latent-diffusion')
-sys.path.append('./cloob-latent-diffusion/taming-transformers')
-sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
+import random
 from omegaconf import OmegaConf
 from PIL import Image
 import torch
@@ -26,17 +19,21 @@ import ldm.models.autoencoder
 from diffusion import sampling, utils
 import train_latent_diffusion as train
 from huggingface_hub import hf_hub_url, cached_download
-import
+import gradio as gr # 🎨 The magic canvas for AI-powered image generation!
 
-# Download the model files
+# 🖼️ Download the necessary model files
+# These files are loaded from HuggingFace's repository
 checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
 ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
 ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
 
-#
-
+# 📐 Utility Functions: Math and images, what could go wrong?
+# These functions help parse prompts and resize/crop images to fit nicely
 
 def parse_prompt(prompt, default_weight=3.):
+    """
+    🎯 Parses a prompt into text and weight.
+    """
     if prompt.startswith('http://') or prompt.startswith('https://'):
         vals = prompt.rsplit(':', 2)
         vals = [vals[0] + ':' + vals[1], *vals[2:]]
@@ -45,31 +42,35 @@ def parse_prompt(prompt, default_weight=3.):
     vals = vals + ['', default_weight][len(vals):]
     return vals[0], float(vals[1])
 
-
 def resize_and_center_crop(image, size):
+    """
+    ✂️ Resize and crop image to center it beautifully.
+    """
     fac = max(size[0] / image.size[0], size[1] / image.size[1])
     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
     return TF.center_crop(image, size[::-1])
 
 
-#
+# 🧠 Model loading: the brain of our operation! 🔥
+# Load all the models: autoencoder, diffusion, and CLOOB
+
 device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 print('Using device:', device)
-print('loading models')
+print('loading models... 🛠️')
 
-#
+# 🔧 Autoencoder Setup: Let’s decode the madness into images
 ae_config = OmegaConf.load(ae_config_path)
 ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
 ae_model.eval().requires_grad_(False).to(device)
 ae_model.load_state_dict(torch.load(ae_model_path))
 n_ch, side_y, side_x = 4, 32, 32
 
-#
+# 🌀 Diffusion Model Setup: The artist behind the scenes
 model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
 model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
 model = model.to(device).eval().requires_grad_(False)
 
-# CLOOB
+# 👁️ CLOOB Setup: Our vision model to understand art in human style
 cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
 cloob = model_pt.get_pt_model(cloob_config)
 checkpoint = pretrained.download_checkpoint(cloob_config)
@@ -77,93 +78,98 @@ cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
 cloob.eval().requires_grad_(False).to(device)
 
 
-# The key function:
-… (removed lines not shown in this view)
+# 🎨 The key function: Where the magic happens!
+# This is where we generate images based on text and image prompts
+
+def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15, method='plms', eta=None):
+    """
+    🖼️ Generates a list of PIL images based on given text and image prompts.
+    """
+    zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
+    target_embeds, weights = [zero_embed], []
+
+    # Parse text prompts
+    for prompt in prompts:
+        txt, weight = parse_prompt(prompt)
+        target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
+        weights.append(weight)
+
+    # Parse image prompts
+    for prompt in images:
+        path, weight = parse_prompt(prompt)
+        img = Image.open(utils.fetch(path)).convert('RGB')
+        clip_size = cloob.config['image_encoder']['image_size']
+        img = resize_and_center_crop(img, (clip_size, clip_size))
+        batch = TF.to_tensor(img)[None].to(device)
+        embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
+        target_embeds.append(embed)
+        weights.append(weight)
+
+    # Adjust weights and set seed
+    weights = torch.tensor([1 - sum(weights), *weights], device=device)
+    torch.manual_seed(seed)
+
+    # 💡 Model function with classifier-free guidance
+    def cfg_model_fn(x, t):
+        n = x.shape[0]
+        n_conds = len(target_embeds)
+        x_in = x.repeat([n_conds, 1, 1, 1])
+        t_in = t.repeat([n_conds])
+        clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
+        vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
+        v = vs.mul(weights[:, None, None, None, None]).sum(0)
+        return v
+
+    # 🎞️ Run the sampler to generate images
+    def run(x, steps):
+        if method == 'ddpm':
+            return sampling.sample(cfg_model_fn, x, steps, 1., {})
+        if method == 'ddim':
+            return sampling.sample(cfg_model_fn, x, steps, eta, {})
+        if method == 'plms':
+            return sampling.plms_sample(cfg_model_fn, x, steps, {})
+        assert False
+
+    # 🏃‍♂️ Generate the output images
+    batch_size = n
+    x = torch.randn([n, n_ch, side_y, side_x], device=device)
+    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
+    pil_ims = []
+    for i in trange(0, n, batch_size):
+        cur_batch_size = min(n - i, batch_size)
+        out_latents = run(x[i:i + cur_batch_size], steps)
+        outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
+        for j, out in enumerate(outs):
+            pil_ims.append(utils.to_pil_image(out))
+
+    return pil_ims
+
+
+# 🖌️ Interface: Gradio's brush to paint the UI
+# Gradio is used here to create a user-friendly interface for art generation.
 
 def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
-… (removed lines not shown in this view)
-    im_prompts = [
-… (removed lines not shown in this view)
+    """
+    💡 Gradio function to wrap image generation.
+    """
+    if seed is None:
+        seed = random.randint(0, 10000)
+    prompts = [prompt]
+    im_prompts = []
+    if im_prompt is not None:
+        im_prompts = [im_prompt]
+    pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
+    return pil_ims[0]
+
+# 🖼️ Gradio UI: The interface where users can input text or image prompts
+iface = gr.Interface(
+    fn=gen_ims,
+    inputs=[
+        gr.Textbox(label="Text prompt"),
+        gr.Image(optional=True, label="Image prompt", type='filepath')
+    ],
+    outputs=gr.Image(type="pil", label="Generated Image"),
+    examples=[
         ["Virgin and Child, in the style of Jacopo Bellini"],
         ["Art Nouveau, in the style of John Singer Sargent"],
         ["Neoclassicism, in the style of Gustav Klimt"],
@@ -212,9 +218,10 @@ iface = gr.Interface(fn=gen_ims,
         ["Futurism, in the style of Zdzislaw Beksinski"],
         ["Aaron Wacker, oil on canvas"],
     ],
-… (removed lines not shown in this view)
+    title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia',
+    description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
+    article='Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa).'
 )
-… (removed line not shown in this view)
+
+# 🚀 Launch the Gradio interface
+iface.launch(enable_queue=True)
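
For reference, a minimal usage sketch of the helpers this commit touches (parse_prompt, generate, gen_ims). It assumes the imports and model downloads above succeed on the Space; the prompt text, weight, seed, and output filename below are illustrative only and not part of the commit.

# Minimal sketch — illustrative values, not part of app.py
text, weight = parse_prompt("Art Nouveau, in the style of John Singer Sargent:2.5")
# -> ("Art Nouveau, in the style of John Singer Sargent", 2.5); without a ':weight' suffix the default weight is 3.0

# gen_ims wraps generate() and returns a single PIL.Image.
# Inside generate(), the weights become torch.tensor([1 - sum(weights), *weights]), so a single
# weight-3 text prompt yields [-2.0, 3.0] for the unconditional/conditional branches in cfg_model_fn.
im = gen_ims("Art Nouveau, in the style of John Singer Sargent", seed=123, n_steps=10, method='plms')
im.save("sample.png")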