In [1]:
import os
import sys
import math
import argparse
import clip
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image

sys.path.append(os.path.dirname(os.getcwd()))

from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score

# Device string shared by every model/tensor placement in this notebook.
# NOTE(review): hard-coded to a CUDA device — this cell presumably fails on a machine without a GPU.
device = 'cuda:0'
# Load the pretrained minDALL-E 1.3B model. The captured cell output shows the
# tokenizer and the stage1/stage2 checkpoints being restored from a cache directory.
model = Dalle.from_pretrained("minDALL-E/1.3B")
# Load CLIP ViT-B/32 together with its image preprocessing callable; both are
# later handed to clip_score() for re-ranking.
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)

# Place both models on the target device.
# NOTE(review): clip.load() was already given `device`, so the first call looks redundant — confirm before removing.
model_clip.to(device=device)
model.to(device=device)

def sampling(prompt, top_k, softmax_temperature, seed, num_candidates=96, num_samples_for_display=36):
    """Sample images for ``prompt``, re-rank them with CLIP and plot the best ones.

    Args:
        prompt: Text prompt forwarded to ``model.sampling`` and ``clip_score``.
        top_k: Top-k value forwarded to ``model.sampling``.
        softmax_temperature: Temperature forwarded to ``model.sampling``.
        seed: Value handed to ``set_seed`` before sampling.
        num_candidates: Number of candidates requested from the model.
        num_samples_for_display: Maximum number of re-ranked images to plot.
            It does not have to be a perfect square, and it is capped at the
            number of images that were actually produced.

    Returns:
        None. The grid of images is rendered through ``plt.show()``.
    """
    # Setup
    set_seed(seed)

    # Sampling
    images = model.sampling(prompt=prompt,
                            top_k=top_k,
                            top_p=None,
                            softmax_temperature=softmax_temperature,
                            num_candidates=num_candidates,
                            device=device).cpu().numpy()
    # Move axis 1 to the end so each image can be passed to imshow
    # (assumes the model returns (N, C, H, W) — TODO confirm).
    images = np.transpose(images, (0, 2, 3, 1))

    # CLIP Re-ranking: `rank` is used as an index array that reorders the batch.
    rank = clip_score(prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device)
    images = images[rank]

    # Never try to display more images than exist (num_candidates may be
    # smaller than num_samples_for_display).
    num_shown = min(num_samples_for_display, len(images))
    if num_shown <= 0:
        return
    images = images[:num_shown]

    # Smallest near-square grid that holds every image. The previous
    # int(sqrt(n)) x int(sqrt(n)) grid was too small for non-square counts and
    # made add_subplot raise.
    n_col = math.ceil(math.sqrt(num_shown))
    n_row = math.ceil(num_shown / n_col)

    # figsize is (width, height): width follows the columns, height the rows.
    fig = plt.figure(figsize=(8 * n_col, 8 * n_row))

    for i, image in enumerate(images):
        ax = fig.add_subplot(n_row, n_col, i + 1)
        ax.imshow(image)
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()

100%|█████████████████████████████████████| 4.72G/4.72G [02:04<00:00, 40.7MiB/s]
extracting: ./1.3B/tokenizer/bpe-16k-vocab.json (size:0MB): 100%|██████████| 7/7 [00:59<00:00, 8.51s/it]


/root/.cache/minDALL-E/1.3B/tokenizer successfully restored..
/root/.cache/minDALL-E/1.3B/stage1_last.ckpt successfully restored..


 0%| | 0.00/338M [00:00<?, ?iB/s]

/root/.cache/minDALL-E/1.3B/stage2_last.ckpt succesfully restored..


100%|███████████████████████████████████████| 338M/338M [00:09<00:00, 38.5MiB/s]


In [2]:
import ipywidgets as widgets
from IPython.display import display
from IPython.display import clear_output

# Two independent capture areas: `output` receives the textual summary of the
# chosen settings, `plot_output` receives whatever sampling() renders.
output, plot_output = widgets.Output(), widgets.Output()

def btn_eventhandler(obj):
    """Click callback: echo the current widget settings, then sample and plot.

    Args:
        obj: Argument supplied by the button's click dispatcher; not used here.
    """
    # Drop whatever the previous click left in either output area.
    for panel in (output, plot_output):
        panel.clear_output()

    # Summary of the settings, one line per control.
    summary = (
        f'SEED: {slider_seed.value}',
        f'Softmax Temperature: {slider_temp.value}',
        f'Top-K: {slider_topk.value}',
        f'Text prompt: {wd_text.value}',
    )
    with output:
        for line in summary:
            print(line)

    # Anything sampling() displays is captured by the plot area.
    with plot_output:
        sampling(
            prompt=wd_text.value,
            top_k=slider_topk.value,
            softmax_temperature=slider_temp.value,
            seed=slider_seed.value,
        )
 
# --- Controls -----------------------------------------------------------------
# Random seed handed to sampling(), which passes it to set_seed().
slider_seed = widgets.IntSlider(
    min=0,
    max=1024,
    step=1,
    description='RND SEED: ',
    value=0
)
# Top-k cutoff. The minimum is the first non-zero tick (the tick grid and the
# default are unchanged): k = 0 would leave no candidate token to sample from.
# NOTE(review): presumed from how top-k sampling normally works — confirm against model.sampling.
slider_topk = widgets.IntSlider(
    min=16,
    max=512,
    step=16,
    description='TOP-K:',
    value=256
)
# Softmax temperature. The minimum is the first non-zero tick (grid and default
# unchanged): a temperature of exactly 0 presumably divides the logits by zero.
# NOTE(review): confirm against model.sampling.
slider_temp = widgets.FloatSlider(
    min=0.2,
    max=5.0,
    step=0.2,
    description='SOFTMAX TEMPERATURE:',
    value=1.0
)
# Free-text prompt used for both generation and CLIP re-ranking.
wd_text = widgets.Text(
    value='A painting of a monkey with sunglasses in the frame',
    placeholder='Text prompt',
    description='String:',
    disabled=False
)

# --- Layout -------------------------------------------------------------------
display(slider_seed)
display(slider_temp)
display(slider_topk)
display(wd_text)

# Clicking the button runs btn_eventhandler with the current widget values.
btn = widgets.Button(description='Generate!')
display(btn)
btn.on_click(btn_eventhandler)

# Output areas filled by the click handler: text summary first, then the image grid.
display(output)
display(plot_output)

IntSlider(value=0, description='RND SEED: ', max=1024)

FloatSlider(value=1.0, description='SOFTMAX TEMPERATURE:', max=5.0, step=0.2)

IntSlider(value=256, description='TOP-K:', max=512, step=16)

Text(value='A painting of a monkey with sunglasses in the frame', description='String:', placeholder='Text pro…

Button(description='Generate!', style=ButtonStyle())

Output()

Output()