Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -20,6 +20,11 @@ from glide_text2im.model_creation import (
|
|
20 |
model_and_diffusion_defaults_upsampler
|
21 |
)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
23 |
# This notebook supports both CPU and GPU.
|
24 |
# On CPU, generating one sample may take on the order of 20 minutes.
|
25 |
# On a GPU, it should be under a minute.
|
@@ -193,14 +198,94 @@ def compose_language_descriptions(prompt):
|
|
193 |
out_img = np.array(out_img.data.to('cpu'))
|
194 |
return out_img
|
195 |
|
196 |
-
#
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
import gradio as gr
|
203 |
gr.Interface(title='Compositional Visual Generation with Composable Diffusion Models',
|
204 |
-
description='Demo for Composable Diffusion (~20s per example)
|
205 |
-
fn=
|
206 |
|
|
|
20 |
model_and_diffusion_defaults_upsampler
|
21 |
)
|
22 |
|
23 |
+
from composable_diffusion.download import download_model
|
24 |
+
from composable_diffusion.model_creation import create_model_and_diffusion as create_model_and_diffusion_for_clevr
|
25 |
+
from composable_diffusion.model_creation import model_and_diffusion_defaults as model_and_diffusion_defaults_for_clevr
|
26 |
+
|
27 |
+
|
28 |
# This notebook supports both CPU and GPU.
|
29 |
# On CPU, generating one sample may take on the order of 20 minutes.
|
30 |
# On a GPU, it should be under a minute.
|
|
|
198 |
out_img = np.array(out_img.data.to('cpu'))
|
199 |
return out_img
|
200 |
|
201 |
+
# create model for CLEVR Objects
|
202 |
+
timestep_respacing = 100
|
203 |
+
clevr_options = model_and_diffusion_defaults_for_clevr()
|
204 |
+
|
205 |
+
flags = {
|
206 |
+
"image_size": 128,
|
207 |
+
"num_channels": 192,
|
208 |
+
"num_res_blocks": 2,
|
209 |
+
"learn_sigma": True,
|
210 |
+
"use_scale_shift_norm": False,
|
211 |
+
"raw_unet": True,
|
212 |
+
"noise_schedule": "squaredcos_cap_v2",
|
213 |
+
"rescale_learned_sigmas": False,
|
214 |
+
"rescale_timesteps": False,
|
215 |
+
"num_classes": '2',
|
216 |
+
"dataset": "clevr_pos",
|
217 |
+
"use_fp16": has_cuda,
|
218 |
+
"timestep_respacing": str(timestep_respacing)
|
219 |
+
}
|
220 |
+
|
221 |
+
for key, val in flags.items():
|
222 |
+
clevr_options[key] = val
|
223 |
+
|
224 |
+
clevr_model, clevr_diffusion = create_model_and_diffusion_for_clevr(**clevr_options)
|
225 |
+
clevr_model.eval()
|
226 |
+
if has_cuda:
|
227 |
+
clevr_model.convert_to_fp16()
|
228 |
+
|
229 |
+
clevr_model.to(device)
|
230 |
+
clevr_model.load_state_dict(th.load(download_model('clevr_pos'), device))
|
231 |
+
|
232 |
+
def compose_clevr_objects(coordinates):
|
233 |
+
coordinates = [[float(x.split(',')[0].strip()), float(x.split(',')[1].strip())]
|
234 |
+
for x in coordinates.split('|')]
|
235 |
+
coordinates += [[-1, -1]] # add unconditional score label
|
236 |
+
batch_size = 1
|
237 |
+
guidance_scale = 10
|
238 |
+
|
239 |
+
def model_fn(x_t, ts, **kwargs):
|
240 |
+
half = x_t[:1]
|
241 |
+
combined = th.cat([half] * kwargs['y'].size(0), dim=0)
|
242 |
+
model_out = model(combined, ts, **kwargs)
|
243 |
+
eps, rest = model_out[:, :3], model_out[:, 3:]
|
244 |
+
masks = kwargs.get('masks')
|
245 |
+
cond_eps = eps[masks].mean(dim=0, keepdim=True)
|
246 |
+
uncond_eps = eps[~masks].mean(dim=0, keepdim=True)
|
247 |
+
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
|
248 |
+
eps = th.cat([half_eps] * x_t.size(0), dim=0)
|
249 |
+
return th.cat([eps, rest], dim=1)
|
250 |
+
|
251 |
+
masks = [True] * (len(coordinates) - 1) + [False]
|
252 |
+
model_kwargs = dict(
|
253 |
+
y=th.tensor(coordinates, dtype=th.float, device=device),
|
254 |
+
masks=th.tensor(masks, dtype=th.bool, device=device)
|
255 |
+
)
|
256 |
+
|
257 |
+
def sample(coordinates):
|
258 |
+
samples = diffusion.p_sample_loop(
|
259 |
+
model_fn,
|
260 |
+
(len(coordinates), 3, options["image_size"], options["image_size"]),
|
261 |
+
device=device,
|
262 |
+
clip_denoised=True,
|
263 |
+
progress=True,
|
264 |
+
model_kwargs=model_kwargs,
|
265 |
+
cond_fn=None,
|
266 |
+
)[:batch_size]
|
267 |
+
|
268 |
+
return samples
|
269 |
+
|
270 |
+
samples = sample(coordinates)
|
271 |
+
out_img = samples[0].permute(1,2,0)
|
272 |
+
out_img = (out_img+1)/2
|
273 |
+
out_img = np.array(out_img.data.to('cpu'))
|
274 |
+
return out_img
|
275 |
+
|
276 |
+
|
277 |
+
def compose(prompt, ver):
|
278 |
+
if ver == 'GLIDE':
|
279 |
+
return compose_language_descriptions(prompt)
|
280 |
+
else:
|
281 |
+
return compose_clevr_objects(prompt)
|
282 |
+
|
283 |
+
examples_1 = ['a camel | a forest', 'A cloudy blue sky | A mountain in the horizon | Cherry Blossoms in front of the mountain']
|
284 |
+
examples_2 = ['0.1, 0.5 | 0.3, 0.5 | 0.5, 0.5 | 0.7, 0.5 | 0.9, 0.5']
|
285 |
+
examples = [[examples_1, 'GLIDE'], [examples_2, 'CLEVR Objects']]
|
286 |
|
287 |
import gradio as gr
|
288 |
gr.Interface(title='Compositional Visual Generation with Composable Diffusion Models',
|
289 |
+
description='<p>Demo for Composable Diffusion (~20s per example)</p><p>See more information from our <a href="https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/">Project Page</a>.</p><ul><li>One version is based on the released <a href="https://github.com/openai/glide-text2im">GLIDE</a> for composing natural language description.</li><li>Another is based on our pre-trained CLEVR Object Model for composing objects. <br>(<b>Note</b>: We recommend using <b><i>x</i></b> in range <b><i>[0.1, 0.9]</i></b> and <b><i>y</i></b> in range <b><i>[0.25, 0.7]</i></b>, since the training dataset labels are in given ranges.).</li></ul><p>When composing multiple sentences, use `|` as the delimiter, see given examples below.</p>',
|
290 |
+
fn=compose, inputs=['text', gr.inputs.Radio(['GLIDE','CLEVR Objects'], type="value", default='GLIDE', label='version')], outputs='image', examples=examples).launch();
|
291 |
|