File size: 2,108 Bytes
8513f87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import argparse
import torch
import gradio as gr
from torchvision import transforms
from runner import MaskGIT
import numpy as np
import random
import torchvision.utils as vutils
class Args(argparse.Namespace):
    """Hard-coded configuration handed to ``MaskGIT`` in place of parsed CLI args.

    Every option is a class attribute, so ``Args()`` instances read them through
    ordinary attribute lookup.

    NOTE(review): because the options live on the class, ``vars(args)`` /
    ``args.__dict__`` of an instance is empty — confirm ``MaskGIT`` only uses
    attribute access on its ``args``.
    """

    data_folder = ""
    vqgan_folder = "pretrained_maskgit/VQGAN"
    writer_log = ""
    data = ""
    mask_value = 1024
    seed = 1
    channel = 3
    num_workers = 0
    iter = 0  # name expected by the consumer of this config; intentionally shadows the builtin here
    global_epoch = 0
    lr = 1e-4
    drop_label = 0.1
    resume = True
    device = "cpu"
    # Runs once, when the class body executes (i.e. on import), printing the device.
    print(device)
    debug = True
    test_only = False
    is_master = True
    is_multi_gpus = False
    vit_size = "base"
    vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
    img_size = 256
    # Derived from img_size so the two values cannot drift apart; presumably the
    # token-grid side for a 16x-downsampling VQGAN — TODO confirm in runner.
    patch_size = img_size // 16
def set_seed(seed):
    """Seed Python, NumPy and torch RNGs and request deterministic cuDNN behaviour.

    Args:
        seed: seed value; anything ``int()`` accepts (Gradio sliders deliver
            floats, which ``np.random.seed`` would reject). Values <= 0 leave
            every RNG and backend flag untouched.
    """
    seed = int(seed)
    if seed <= 0:
        return
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # ignored silently when CUDA is unavailable
    np.random.seed(seed)
    random.seed(seed)
    # The cuDNN switch is ``enabled``; assigning ``enable`` only creates an
    # unused attribute and leaves cuDNN on.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
# Module-level setup: build the configuration once and construct the model once;
# synthesize_image reads both (args.device and maskgit.sample) for each call.
# NOTE(review): MaskGIT(args) presumably loads the checkpoints named by
# args.vqgan_folder / args.vit_folder (args.resume is True) — confirm in runner.
args = Args()
maskgit = MaskGIT(args)
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
    """Generate class-conditional MaskGIT samples and return them as one image.

    Gradio ``Number``/``Slider`` components deliver floats, so the integer
    parameters are cast with ``int()`` first: ``[cls] * nb_img`` raises
    ``TypeError`` for a float ``nb_img``, and ``step`` is presumably used as an
    iteration count inside ``maskgit.sample`` (confirm in runner).

    Args:
        cls: class label every sample is conditioned on.
        sm_temp: forwarded to ``maskgit.sample`` as ``sm_temp``.
        w: forwarded to ``maskgit.sample`` as ``w``.
        r_temp: forwarded to ``maskgit.sample`` as ``r_temp``.
        step: forwarded to ``maskgit.sample`` as ``step``.
        seed: passed to ``set_seed`` before sampling (<= 0 skips seeding).
        nb_img: number of samples to generate; must be >= 1.

    Returns:
        PIL image of the samples tiled by ``vutils.make_grid`` (2 per row,
        no padding, normalized).

    Raises:
        ValueError: if ``nb_img`` is less than 1.
    """
    cls, step, seed, nb_img = int(cls), int(step), int(seed), int(nb_img)
    if nb_img < 1:
        raise ValueError(f"nb_img must be >= 1, got {nb_img}")

    set_seed(seed)
    with torch.no_grad():
        # One identical label per requested image, created directly on args.device.
        labels = torch.full((nb_img,), cls, dtype=torch.long, device=args.device)
        # sample() returns an indexable result; element 0 is taken as the image
        # batch (assumed from usage — TODO confirm in runner).
        gen_sample = maskgit.sample(
            nb_sample=nb_img, labels=labels, sm_temp=sm_temp, w=w,
            randomize="linear", r_temp=r_temp, sched_mode="arccos", step=step,
        )[0]
    grid = vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True)
    return transforms.ToPILImage()(grid)
# Gradio interface: one input component per synthesize_image parameter, in order
# (cls, sm_temp, w, r_temp, step, seed, nb_img). Integer-valued fields use
# precision=0 so Gradio hands the function ints instead of floats.
app = gr.Interface(
    fn=synthesize_image,
    inputs=[
        gr.Number(value=31, label="Class label", precision=0),
        gr.Number(value=1.3, label="Softmax temperature (sm_temp)"),
        gr.Number(value=25, label="Guidance weight (w)"),
        gr.Number(value=4.5, label="Randomization temperature (r_temp)"),
        gr.Number(value=16, label="Sampling steps", precision=0),
        gr.Slider(0, 1000, value=60, step=1, label="Seed (0 = unseeded)"),
        gr.Number(value=1, label="Number of images", precision=0, minimum=1, maximum=4),
    ],
    outputs=gr.Image(),
    title="Image Synthesis using MaskGIT",
)
# Launch the Gradio app; share=True also requests a public share link.
app.launch(share=True)
|