"""Gradio demo that samples images from a pretrained MaskGIT model."""

import argparse
import random

import gradio as gr
import numpy as np
import torch
import torchvision.utils as vutils
from torchvision import transforms

from runner import MaskGIT


class Args(argparse.Namespace):
    """Static configuration handed to ``MaskGIT`` in place of parsed CLI args.

    The values are class attributes, so ``Args()`` yields a namespace that
    already carries every setting below. How each field is interpreted is up
    to ``runner.MaskGIT`` (not visible in this file).
    """

    # Paths / logging. Empty strings are passed through unchanged.
    data_folder = ""
    vqgan_folder = "pretrained_maskgit/VQGAN"
    writer_log = ""
    data = ""

    # NOTE(review): presumably the token id used for masked positions --
    # confirm against runner.MaskGIT.
    mask_value = 1024
    seed = 1
    channel = 3
    num_workers = 0
    iter = 0
    global_epoch = 0
    lr = 1e-4
    drop_label = 0.1
    resume = True

    # The demo runs on CPU; the value is echoed once when the class is defined.
    device = "cpu"
    print(device)

    debug = True
    test_only = False
    is_master = True
    is_multi_gpus = False

    vit_size = "base"
    vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
    img_size = 256
    patch_size = 256 // 16


def set_seed(seed):
    """Seed the torch, NumPy and ``random`` RNGs for reproducible sampling.

    Args:
        seed: Seed value. Nothing is changed when ``seed`` is not positive.
    """
    if seed > 0:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        # The flag is spelled ``enabled``; assigning to ``enable`` only
        # creates an unused attribute and leaves cuDNN switched on.
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.deterministic = True


args = Args()
maskgit = MaskGIT(args)


# Function to perform image synthesis
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
    """Sample ``nb_img`` images of class ``cls`` and return them as one PIL image.

    Args:
        cls: Class label used for every generated sample.
        sm_temp: Forwarded to ``maskgit.sample`` as ``sm_temp``.
        w: Forwarded to ``maskgit.sample`` as ``w``.
        r_temp: Forwarded to ``maskgit.sample`` as ``r_temp``.
        step: Forwarded to ``maskgit.sample`` as ``step``.
        seed: RNG seed; see ``set_seed``.
        nb_img: Number of images to generate (must be at least 1).

    Returns:
        A PIL image containing the samples tiled two per row.

    Raises:
        ValueError: If ``nb_img`` is smaller than 1.
    """
    # UI widgets may deliver floats (e.g. 31.0); list repetition, LongTensor
    # construction and np.random.seed all need real integers.
    cls = int(cls)
    step = int(step)
    seed = int(seed)
    nb_img = int(nb_img)
    if nb_img < 1:
        raise ValueError(f"nb_img must be >= 1, got {nb_img}")

    set_seed(seed)
    with torch.no_grad():
        labels = torch.LongTensor([cls] * nb_img).to(args.device)
        # Only the first element of the sampler's return value is used here.
        gen_sample = maskgit.sample(
            nb_sample=labels.size(0),
            labels=labels,
            sm_temp=sm_temp,
            w=w,
            randomize="linear",
            r_temp=r_temp,
            sched_mode="arccos",
            step=step,
        )[0]

    # Tile the batch into a single grid and rescale values for display.
    grid = vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True)
    output_image = transforms.ToPILImage()(grid)
    return output_image


# Gradio Interface
# Inputs are listed in the positional order of synthesize_image's parameters.
app = gr.Interface(
    fn=synthesize_image,
    inputs=[
        gr.Number(31, label="Class label (cls)", precision=0),
        gr.Number(1.3, label="sm_temp"),
        gr.Number(25, label="w"),
        gr.Number(4.5, label="r_temp"),
        gr.Number(16, label="Steps (step)", precision=0),
        # step=1 so that every integer seed in the range can be selected.
        gr.Slider(0, 1000, 60, step=1, label="Seed"),
        gr.Number(1, label="Number of images (nb_img)", minimum=1, maximum=4, precision=0),
    ],
    outputs=gr.Image(),
    title="Image Synthesis using MaskGIT",
)

# Launch the Gradio app
# NOTE: share=True requests a public share link for the running demo.
app.launch(share=True)