Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
|
4 |
+
os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
from functools import partial
|
8 |
+
from pathlib import Path
|
9 |
+
import sys
|
10 |
+
sys.path.append('./cloob-latent-diffusion')
|
11 |
+
sys.path.append('./cloob-latent-diffusion/cloob-training')
|
12 |
+
sys.path.append('./cloob-latent-diffusion/latent-diffusion')
|
13 |
+
sys.path.append('./cloob-latent-diffusion/taming-transformers')
|
14 |
+
sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
from PIL import Image
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
from torchvision import transforms
|
21 |
+
from torchvision.transforms import functional as TF
|
22 |
+
from tqdm import trange
|
23 |
+
from CLIP import clip
|
24 |
+
from cloob_training import model_pt, pretrained
|
25 |
+
import ldm.models.autoencoder
|
26 |
+
from diffusion import sampling, utils
|
27 |
+
import train_latent_diffusion as train
|
28 |
+
from huggingface_hub import hf_hub_url, cached_download
|
29 |
+
import random
|
30 |
+
|
31 |
+
# Download the model files
|
32 |
+
checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
|
33 |
+
ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
|
34 |
+
ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
|
35 |
+
|
36 |
+
# Define a few utility functions
|
37 |
+
|
38 |
+
def parse_prompt(prompt, default_weight=3.):
|
39 |
+
if prompt.startswith('http://') or prompt.startswith('https://'):
|
40 |
+
vals = prompt.rsplit(':', 2)
|
41 |
+
vals = [vals[0] + ':' + vals[1], *vals[2:]]
|
42 |
+
else:
|
43 |
+
vals = prompt.rsplit(':', 1)
|
44 |
+
vals = vals + ['', default_weight][len(vals):]
|
45 |
+
return vals[0], float(vals[1])
|
46 |
+
|
47 |
+
|
48 |
+
def resize_and_center_crop(image, size):
|
49 |
+
fac = max(size[0] / image.size[0], size[1] / image.size[1])
|
50 |
+
image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
|
51 |
+
return TF.center_crop(image, size[::-1])
|
52 |
+
|
53 |
+
|
54 |
+
# Load the models
|
55 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
56 |
+
print('Using device:', device)
|
57 |
+
print('loading models')
|
58 |
+
|
59 |
+
# autoencoder
|
60 |
+
ae_config = OmegaConf.load(ae_config_path)
|
61 |
+
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
|
62 |
+
ae_model.eval().requires_grad_(False).to(device)
|
63 |
+
ae_model.load_state_dict(torch.load(ae_model_path))
|
64 |
+
n_ch, side_y, side_x = 4, 32, 32
|
65 |
+
|
66 |
+
# diffusion model
|
67 |
+
model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
|
68 |
+
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
|
69 |
+
model = model.to(device).eval().requires_grad_(False)
|
70 |
+
|
71 |
+
# CLOOB
|
72 |
+
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
|
73 |
+
cloob = model_pt.get_pt_model(cloob_config)
|
74 |
+
checkpoint = pretrained.download_checkpoint(cloob_config)
|
75 |
+
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
|
76 |
+
cloob.eval().requires_grad_(False).to(device)
|
77 |
+
|
78 |
+
|
79 |
+
# The key function: returns a list of n PIL images
|
80 |
+
def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
|
81 |
+
method='plms', eta=None):
|
82 |
+
zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
|
83 |
+
target_embeds, weights = [zero_embed], []
|
84 |
+
|
85 |
+
for prompt in prompts:
|
86 |
+
txt, weight = parse_prompt(prompt)
|
87 |
+
target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
|
88 |
+
weights.append(weight)
|
89 |
+
|
90 |
+
for prompt in images:
|
91 |
+
path, weight = parse_prompt(prompt)
|
92 |
+
img = Image.open(utils.fetch(path)).convert('RGB')
|
93 |
+
clip_size = cloob.config['image_encoder']['image_size']
|
94 |
+
img = resize_and_center_crop(img, (clip_size, clip_size))
|
95 |
+
batch = TF.to_tensor(img)[None].to(device)
|
96 |
+
embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
|
97 |
+
target_embeds.append(embed)
|
98 |
+
weights.append(weight)
|
99 |
+
|
100 |
+
weights = torch.tensor([1 - sum(weights), *weights], device=device)
|
101 |
+
|
102 |
+
torch.manual_seed(seed)
|
103 |
+
|
104 |
+
def cfg_model_fn(x, t):
|
105 |
+
n = x.shape[0]
|
106 |
+
n_conds = len(target_embeds)
|
107 |
+
x_in = x.repeat([n_conds, 1, 1, 1])
|
108 |
+
t_in = t.repeat([n_conds])
|
109 |
+
clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
|
110 |
+
vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
|
111 |
+
v = vs.mul(weights[:, None, None, None, None]).sum(0)
|
112 |
+
return v
|
113 |
+
|
114 |
+
def run(x, steps):
|
115 |
+
if method == 'ddpm':
|
116 |
+
return sampling.sample(cfg_model_fn, x, steps, 1., {})
|
117 |
+
if method == 'ddim':
|
118 |
+
return sampling.sample(cfg_model_fn, x, steps, eta, {})
|
119 |
+
if method == 'prk':
|
120 |
+
return sampling.prk_sample(cfg_model_fn, x, steps, {})
|
121 |
+
if method == 'plms':
|
122 |
+
return sampling.plms_sample(cfg_model_fn, x, steps, {})
|
123 |
+
if method == 'pie':
|
124 |
+
return sampling.pie_sample(cfg_model_fn, x, steps, {})
|
125 |
+
if method == 'plms2':
|
126 |
+
return sampling.plms2_sample(cfg_model_fn, x, steps, {})
|
127 |
+
assert False
|
128 |
+
|
129 |
+
batch_size = n
|
130 |
+
x = torch.randn([n, n_ch, side_y, side_x], device=device)
|
131 |
+
t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
|
132 |
+
steps = utils.get_spliced_ddpm_cosine_schedule(t)
|
133 |
+
pil_ims = []
|
134 |
+
for i in trange(0, n, batch_size):
|
135 |
+
cur_batch_size = min(n - i, batch_size)
|
136 |
+
out_latents = run(x[i:i+cur_batch_size], steps)
|
137 |
+
outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
|
138 |
+
for j, out in enumerate(outs):
|
139 |
+
pil_ims.append(utils.to_pil_image(out))
|
140 |
+
|
141 |
+
return pil_ims
|
142 |
+
|
143 |
+
|
144 |
+
import gradio as gr
|
145 |
+
|
146 |
+
def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
|
147 |
+
if seed == None :
|
148 |
+
seed = random.randint(0, 10000)
|
149 |
+
print( prompt, im_prompt, seed, n_steps)
|
150 |
+
prompts = [prompt]
|
151 |
+
im_prompts = []
|
152 |
+
if im_prompt != None:
|
153 |
+
im_prompts = [im_prompt]
|
154 |
+
pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
|
155 |
+
return pil_ims[0]
|
156 |
+
|
157 |
+
iface = gr.Interface(fn=gen_ims,
|
158 |
+
inputs=[#gr.inputs.Slider(minimum=1, maximum=1, step=1, default=1,label="Number of images"),
|
159 |
+
#gr.inputs.Slider(minimum=0, maximum=200, step=1, label='Random seed', default=0),
|
160 |
+
gr.inputs.Textbox(label="Text prompt"),
|
161 |
+
gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
|
162 |
+
#gr.inputs.Slider(minimum=10, maximum=35, step=1, default=15,label="Number of steps")
|
163 |
+
],
|
164 |
+
outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
|
165 |
+
examples=[
|
166 |
+
["Impressionism, oil on canvas"],
|
167 |
+
["Futurism, in the style of Wassily Kandinsky"],
|
168 |
+
["Art Nouveau, in the style of John Singer Sargent"],
|
169 |
+
["Surrealism, in the style of Edgar Degas"],
|
170 |
+
["Expressionism, in the style of Wassily Kandinsky"],
|
171 |
+
["Futurism, in the style of Egon Schiele"],
|
172 |
+
["Neoclassicism, in the style of Gustav Klimt"],
|
173 |
+
["Cubism, in the style of Gustav Klimt"],
|
174 |
+
["Op Art, in the style of Marc Chagall"],
|
175 |
+
["Romanticism, in the style of M.C. Escher"],
|
176 |
+
["Futurism, in the style of M.C. Escher"],
|
177 |
+
["Abstract Art, in the style of M.C. Escher"],
|
178 |
+
["Mannerism, in the style of Paul Klee"],
|
179 |
+
["Romanesque Art, in the style of Leonardo da Vinci"],
|
180 |
+
["High Renaissance, in the style of Rembrandt"],
|
181 |
+
["Magic Realism, in the style of Gustave Dore"],
|
182 |
+
["Realism, in the style of Jean-Michel Basquiat"],
|
183 |
+
["Art Nouveau, in the style of Paul Gauguin"],
|
184 |
+
["Avant-garde, in the style of Pierre-Auguste Renoir"],
|
185 |
+
["Baroque, in the style of Edward Hopper"],
|
186 |
+
["Post-Impressionism, in the style of Wassily Kandinsky"],
|
187 |
+
["Naturalism, in the style of Rene Magritte"],
|
188 |
+
["Constructivism, in the style of Paul Cezanne"],
|
189 |
+
["Abstract Expressionism, in the style of Henri Matisse"],
|
190 |
+
["Pop Art, in the style of Vincent van Gogh"],
|
191 |
+
["Futurism, in the style of Wassily Kandinsky"],
|
192 |
+
["Futurism, in the style of Zdzislaw Beksinski"],
|
193 |
+
['Surrealism, in the style of Salvador Dali'],
|
194 |
+
["Aaron Wacker, oil on canvas"]
|
195 |
+
],
|
196 |
+
title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia:',
|
197 |
+
description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
|
198 |
+
article = 'Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa)..'
|
199 |
+
|
200 |
+
)
|
201 |
+
iface.launch(enable_queue=True) # , debug=True for colab debugging
|