# Maskgit-pytorch / app.py
# Author: llvictorll
# Commit 8513f87: "add gradio app"
import argparse
import torch
import gradio as gr
from torchvision import transforms
from runner import MaskGIT
import numpy as np
import random
import torchvision.utils as vutils
class Args(argparse.Namespace):
    """Fixed configuration used to build the MaskGIT model for this demo.

    It stands in for the namespace a command-line ``argparse`` parser would
    produce, so ``MaskGIT(Args())`` can be constructed without a CLI. Every
    setting is a class attribute, so all ``Args()`` instances share them.
    """

    data_folder = ""
    vqgan_folder = "pretrained_maskgit/VQGAN"
    writer_log = ""
    data = ""
    # NOTE(review): presumably the id of the [MASK] token in the VQGAN
    # codebook — confirm against runner.MaskGIT.
    mask_value = 1024
    seed = 1
    channel = 3
    num_workers = 0
    iter = 0
    global_epoch = 0
    lr = 1e-4
    drop_label = 0.1
    # presumably makes MaskGIT load the checkpoint at ``vit_folder`` — confirm
    resume = True
    device = "cpu"
    debug = True
    test_only = False
    is_master = True
    is_multi_gpus = False
    vit_size = "base"
    vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
    img_size = 256
    # Derived from img_size instead of repeating 256, so the two stay in sync.
    # The 16 is assumed to be the VQGAN downsampling factor — TODO confirm.
    patch_size = img_size // 16
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs so a generation can be reproduced.

    Args:
        seed: Seed value. Gradio numeric widgets may deliver it as a float
            (e.g. ``60.0``), so it is converted to ``int`` first, because
            ``np.random.seed`` rejects floats. If the converted seed is
            <= 0, nothing is seeded and the cuDNN flags are left untouched.
    """
    seed = int(seed)
    if seed <= 0:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # The attribute is ``enabled``. Assigning ``enable`` (as before) only
    # created an unused attribute and left cuDNN on.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
# Build the fixed configuration and instantiate the model once, at import
# time, so every Gradio request reuses the same MaskGIT instance.
# NOTE(review): MaskGIT presumably loads the VQGAN/ViT checkpoints named in
# Args here — confirm in runner.MaskGIT.
args = Args()
maskgit = MaskGIT(args)
# Function to perform image synthesis
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
    """Generate ``nb_img`` class-conditional samples with MaskGIT and tile them.

    Args:
        cls: Class label to condition on. Presumably an ImageNet class index,
            given the ImageNet checkpoint configured in ``Args`` — confirm.
        sm_temp: Forwarded unchanged to ``maskgit.sample``.
        w: Forwarded unchanged to ``maskgit.sample``.
        r_temp: Forwarded unchanged to ``maskgit.sample``.
        step: Forwarded to ``maskgit.sample`` as an int.
        seed: RNG seed handed to ``set_seed``; values <= 0 leave RNGs unseeded.
        nb_img: Number of samples to generate. Values below 1 are raised to 1.

    Returns:
        A PIL image: the samples laid out in a grid two per row, without
        padding, normalized by ``make_grid(normalize=True)``.
    """
    # gr.Number without precision=0 delivers floats (e.g. 1.0), which break
    # list repetition and integer step counts, so normalise to int here.
    cls = int(cls)
    step = int(step)
    seed = int(seed)
    nb_img = max(1, int(nb_img))

    set_seed(seed)
    with torch.no_grad():
        labels = torch.full((nb_img,), cls, dtype=torch.long, device=args.device)
        gen_sample = maskgit.sample(
            nb_sample=nb_img,
            labels=labels,
            sm_temp=sm_temp,
            w=w,
            randomize="linear",
            r_temp=r_temp,
            sched_mode="arccos",
            step=step,
        )[0]

    # Post-process: tile the samples and convert to a PIL image for Gradio.
    grid = vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True)
    return transforms.ToPILImage()(grid)
# Gradio Interface. Inputs map positionally onto synthesize_image's
# parameters: cls, sm_temp, w, r_temp, step, seed, nb_img. Integer
# parameters use precision=0 so Gradio hands the function ints, not floats.
app = gr.Interface(
    fn=synthesize_image,
    inputs=[
        gr.Number(31, label="Class label (cls)", precision=0),
        gr.Number(1.3, label="sm_temp"),
        gr.Number(25, label="w"),
        gr.Number(4.5, label="r_temp"),
        gr.Number(16, label="Sampling steps", precision=0),
        gr.Slider(0, 1000, 60, step=1, label="Seed (0 = unseeded)"),
        gr.Number(1, label="Number of images", precision=0, minimum=1, maximum=4),
    ],
    outputs=gr.Image(label="Generated images"),
    title="Image Synthesis using MaskGIT",
)
# Launch the Gradio app only when this file is run as a script
# (e.g. `python app.py`), so importing the module does not start a server.
# share=True asks Gradio for a public share link.
if __name__ == "__main__":
    app.launch(share=True)