|
import argparse |
|
import torch |
|
import gradio as gr |
|
from torchvision import transforms |
|
from runner import MaskGIT |
|
import numpy as np |
|
import random |
|
import torchvision.utils as vutils |
|
|
|
|
|
class Args(argparse.Namespace):
    """Fixed configuration object handed to ``MaskGIT`` by this demo.

    Every setting is a class attribute, so ``Args()`` needs no constructor
    arguments and instances expose the values through normal attribute
    lookup. Subclassing ``argparse.Namespace`` presumably lets it stand in
    for parsed command-line arguments (confirm against ``runner.MaskGIT``).
    """

    # --- paths / logging -------------------------------------------------
    data = ""
    data_folder = ""
    writer_log = ""
    vqgan_folder = "pretrained_maskgit/VQGAN"
    vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"

    # --- model geometry --------------------------------------------------
    vit_size = "base"
    img_size = 256
    patch_size = img_size // 16  # evaluates to 16, same as 256 // 16
    channel = 3
    # NOTE(review): presumably the token id used for masked positions — confirm.
    mask_value = 1024

    # --- training bookkeeping -------------------------------------------
    seed = 1
    lr = 1e-4
    drop_label = 0.1
    iter = 0
    global_epoch = 0
    num_workers = 0
    resume = True

    # --- runtime / execution mode ---------------------------------------
    device = "cpu"
    print(device)  # echoes the chosen device once, when the class is defined
    debug = True
    test_only = False
    is_master = True
    is_multi_gpus = False
|
|
|
|
|
def set_seed(seed):
    """Seed every RNG the demo touches so generation is reproducible.

    Seeds torch (CPU and CUDA), NumPy and Python's ``random`` module, then
    disables cuDNN and requests deterministic cuDNN behaviour. A seed that
    is not positive leaves all RNG state untouched.

    Args:
        seed: Seed value. Floats (Gradio number widgets may send e.g.
            ``60.0``) are accepted and truncated to ``int``.
    """
    if seed <= 0:
        return

    # np.random.seed rejects float seeds, so normalise once up front.
    seed = int(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # The real switch is ``cudnn.enabled``; assigning ``cudnn.enable`` (sic)
    # has no effect on cuDNN and silently left it on.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
|
|
|
# Build the demo configuration and the MaskGIT runner once, at import time;
# ``synthesize_image`` reads both module-level globals on every request.
maskgit = MaskGIT(args := Args())
|
|
|
|
|
|
|
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
    """Sample ``nb_img`` images of class ``cls`` with MaskGIT and return one grid.

    Args:
        cls: Class label every generated image is conditioned on.
        sm_temp: Forwarded to ``maskgit.sample`` as ``sm_temp``
            (presumably a softmax temperature — confirm in ``runner``).
        w: Forwarded as ``w`` (presumably the classifier-free guidance weight).
        r_temp: Forwarded as ``r_temp``, alongside ``randomize="linear"``.
        step: Forwarded as ``step`` (presumably the number of decoding steps).
        seed: Passed to ``set_seed``; values <= 0 leave the RNGs unseeded.
        nb_img: Number of images to generate; must be at least 1.

    Returns:
        PIL.Image.Image: the samples tiled two per row, min-max normalised
        to [0, 1] by ``make_grid`` before conversion.

    Raises:
        ValueError: if ``nb_img`` is less than 1.
    """
    # Gradio Number/Slider widgets can hand us floats (e.g. 4.0); list
    # repetition and integer step/label handling need real ints.
    cls, step, seed, nb_img = int(cls), int(step), int(seed), int(nb_img)
    if nb_img < 1:
        raise ValueError(f"nb_img must be at least 1, got {nb_img}")

    set_seed(seed)

    with torch.no_grad():
        labels = torch.LongTensor([cls] * nb_img).to(args.device)
        # ``sample`` returns a sequence; element 0 holds the generated images.
        gen_sample = maskgit.sample(
            nb_sample=labels.size(0),
            labels=labels,
            sm_temp=sm_temp,
            w=w,
            randomize="linear",
            r_temp=r_temp,
            sched_mode="arccos",
            step=step,
        )[0]

    grid = vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True)
    return transforms.ToPILImage()(grid)
|
|
|
|
|
|
|
# Gradio UI: one input widget per ``synthesize_image`` parameter, in the same
# positional order (cls, sm_temp, w, r_temp, step, seed, nb_img). Integer
# fields use ``precision=0`` so Gradio passes ``int`` instead of ``float``.
app = gr.Interface(
    fn=synthesize_image,
    inputs=[
        gr.Number(31, label="Class label (cls)", precision=0),
        gr.Number(1.3, label="Softmax temperature (sm_temp)"),
        gr.Number(25, label="Guidance weight (w)"),
        gr.Number(4.5, label="Randomization temperature (r_temp)"),
        gr.Number(16, label="Sampling steps (step)", precision=0),
        gr.Slider(0, 1000, 60, step=1, label="Seed (0 = unseeded)"),
        gr.Number(1, label="Number of images (nb_img)", precision=0,
                  minimum=1, maximum=4),
    ],
    outputs=gr.Image(),
    title="Image Synthesis using MaskGIT",
)

# share=True additionally requests a public Gradio share link.
app.launch(share=True)
|
|
|
|