Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
|
4 |
+
os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
from functools import partial
|
8 |
+
from pathlib import Path
|
9 |
+
import sys
|
10 |
+
sys.path.append('./cloob-latent-diffusion')
|
11 |
+
sys.path.append('./cloob-latent-diffusion/cloob-training')
|
12 |
+
sys.path.append('./cloob-latent-diffusion/latent-diffusion')
|
13 |
+
sys.path.append('./cloob-latent-diffusion/taming-transformers')
|
14 |
+
sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
|
15 |
+
from omegaconf import OmegaConf
|
16 |
+
from PIL import Image
|
17 |
+
import torch
|
18 |
+
from torch import nn
|
19 |
+
from torch.nn import functional as F
|
20 |
+
from torchvision import transforms
|
21 |
+
from torchvision.transforms import functional as TF
|
22 |
+
from tqdm import trange
|
23 |
+
from CLIP import clip
|
24 |
+
from cloob_training import model_pt, pretrained
|
25 |
+
import ldm.models.autoencoder
|
26 |
+
from diffusion import sampling, utils
|
27 |
+
import train_latent_diffusion as train
|
28 |
+
from huggingface_hub import hf_hub_url, cached_download
|
29 |
+
import random
|
30 |
+
|
31 |
+
# Download the model files
|
32 |
+
checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
|
33 |
+
ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
|
34 |
+
ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
|
35 |
+
|
36 |
+
# Define a few utility functions
|
37 |
+
|
38 |
+
def parse_prompt(prompt, default_weight=3.):
|
39 |
+
if prompt.startswith('http://') or prompt.startswith('https://'):
|
40 |
+
vals = prompt.rsplit(':', 2)
|
41 |
+
vals = [vals[0] + ':' + vals[1], *vals[2:]]
|
42 |
+
else:
|
43 |
+
vals = prompt.rsplit(':', 1)
|
44 |
+
vals = vals + ['', default_weight][len(vals):]
|
45 |
+
return vals[0], float(vals[1])
|
46 |
+
|
47 |
+
|
48 |
+
def resize_and_center_crop(image, size):
|
49 |
+
fac = max(size[0] / image.size[0], size[1] / image.size[1])
|
50 |
+
image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
|
51 |
+
return TF.center_crop(image, size[::-1])
|
52 |
+
|
53 |
+
|
54 |
+
# Load the models
|
55 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
56 |
+
print('Using device:', device)
|
57 |
+
print('loading models')
|
58 |
+
# autoencoder
|
59 |
+
ae_config = OmegaConf.load(ae_config_path)
|
60 |
+
ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
|
61 |
+
ae_model.eval().requires_grad_(False).to(device)
|
62 |
+
ae_model.load_state_dict(torch.load(ae_model_path))
|
63 |
+
n_ch, side_y, side_x = 4, 32, 32
|
64 |
+
|
65 |
+
# diffusion model
|
66 |
+
model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
|
67 |
+
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
|
68 |
+
model = model.to(device).eval().requires_grad_(False)
|
69 |
+
|
70 |
+
# CLOOB
|
71 |
+
cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
|
72 |
+
cloob = model_pt.get_pt_model(cloob_config)
|
73 |
+
checkpoint = pretrained.download_checkpoint(cloob_config)
|
74 |
+
cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
|
75 |
+
cloob.eval().requires_grad_(False).to(device)
|
76 |
+
|
77 |
+
|
78 |
+
# The key function: returns a list of n PIL images
|
79 |
+
def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
|
80 |
+
method='plms', eta=None):
|
81 |
+
zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
|
82 |
+
target_embeds, weights = [zero_embed], []
|
83 |
+
|
84 |
+
for prompt in prompts:
|
85 |
+
txt, weight = parse_prompt(prompt)
|
86 |
+
target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
|
87 |
+
weights.append(weight)
|
88 |
+
|
89 |
+
for prompt in images:
|
90 |
+
path, weight = parse_prompt(prompt)
|
91 |
+
img = Image.open(utils.fetch(path)).convert('RGB')
|
92 |
+
clip_size = cloob.config['image_encoder']['image_size']
|
93 |
+
img = resize_and_center_crop(img, (clip_size, clip_size))
|
94 |
+
batch = TF.to_tensor(img)[None].to(device)
|
95 |
+
embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
|
96 |
+
target_embeds.append(embed)
|
97 |
+
weights.append(weight)
|
98 |
+
|
99 |
+
weights = torch.tensor([1 - sum(weights), *weights], device=device)
|
100 |
+
|
101 |
+
torch.manual_seed(seed)
|
102 |
+
|
103 |
+
def cfg_model_fn(x, t):
|
104 |
+
n = x.shape[0]
|
105 |
+
n_conds = len(target_embeds)
|
106 |
+
x_in = x.repeat([n_conds, 1, 1, 1])
|
107 |
+
t_in = t.repeat([n_conds])
|
108 |
+
clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
|
109 |
+
vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
|
110 |
+
v = vs.mul(weights[:, None, None, None, None]).sum(0)
|
111 |
+
return v
|
112 |
+
|
113 |
+
def run(x, steps):
|
114 |
+
if method == 'ddpm':
|
115 |
+
return sampling.sample(cfg_model_fn, x, steps, 1., {})
|
116 |
+
if method == 'ddim':
|
117 |
+
return sampling.sample(cfg_model_fn, x, steps, eta, {})
|
118 |
+
if method == 'prk':
|
119 |
+
return sampling.prk_sample(cfg_model_fn, x, steps, {})
|
120 |
+
if method == 'plms':
|
121 |
+
return sampling.plms_sample(cfg_model_fn, x, steps, {})
|
122 |
+
if method == 'pie':
|
123 |
+
return sampling.pie_sample(cfg_model_fn, x, steps, {})
|
124 |
+
if method == 'plms2':
|
125 |
+
return sampling.plms2_sample(cfg_model_fn, x, steps, {})
|
126 |
+
assert False
|
127 |
+
|
128 |
+
batch_size = n
|
129 |
+
x = torch.randn([n, n_ch, side_y, side_x], device=device)
|
130 |
+
t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
|
131 |
+
steps = utils.get_spliced_ddpm_cosine_schedule(t)
|
132 |
+
pil_ims = []
|
133 |
+
for i in trange(0, n, batch_size):
|
134 |
+
cur_batch_size = min(n - i, batch_size)
|
135 |
+
out_latents = run(x[i:i+cur_batch_size], steps)
|
136 |
+
outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
|
137 |
+
for j, out in enumerate(outs):
|
138 |
+
pil_ims.append(utils.to_pil_image(out))
|
139 |
+
|
140 |
+
return pil_ims
|
141 |
+
|
142 |
+
|
143 |
+
import gradio as gr
|
144 |
+
|
145 |
+
def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
|
146 |
+
if seed == None :
|
147 |
+
seed = random.randint(0, 10000)
|
148 |
+
print( prompt, im_prompt, seed, n_steps)
|
149 |
+
prompts = [prompt]
|
150 |
+
im_prompts = []
|
151 |
+
if im_prompt != None:
|
152 |
+
im_prompts = [im_prompt]
|
153 |
+
pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
|
154 |
+
return pil_ims[0]
|
155 |
+
|
156 |
+
iface = gr.Interface(fn=gen_ims,
|
157 |
+
inputs=[#gr.inputs.Slider(minimum=1, maximum=1, step=1, default=1,label="Number of images"),
|
158 |
+
#gr.inputs.Slider(minimum=0, maximum=200, step=1, label='Random seed', default=0),
|
159 |
+
gr.inputs.Textbox(label="Text prompt"),
|
160 |
+
gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
|
161 |
+
#gr.inputs.Slider(minimum=10, maximum=35, step=1, default=15,label="Number of steps")
|
162 |
+
],
|
163 |
+
outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
|
164 |
+
examples=[["An iceberg, oil on canvas"],["A martian landscape, in the style of Monet"], ['A peaceful meadow, pastel crayons'], ["A painting of a vase of flowers"], ["A ship leaving the port in the summer, oil on canvas"]],
|
165 |
+
title='Generate art from text prompts :',
|
166 |
+
description="By typing a text prompt or providing an image prompt, and pressing submit you can generate images based on this prompt. The model was trained on images from the [WikiArt](https://huggingface.co/datasets/huggan/wikiart) dataset, comprised mostly of paintings.",
|
167 |
+
article = 'The model is a distilled version of a cloob-conditioned latent diffusion model fine-tuned on the WikiArt dataset. You can find more information on this model on the [model card](https://huggingface.co/huggan/distill-ccld-wa). According to the [Latent Diffusion paper](https://arxiv.org/abs/2112.10752): \"Deep learning modules tend to reproduce or exacerbate biases that are already present in the data\".'
|
168 |
+
|
169 |
+
)
|
170 |
+
iface.launch(enable_queue=True) # , debug=True for colab debugging
|