File size: 2,108 Bytes
8513f87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import torch
import gradio as gr
from torchvision import transforms
from runner import MaskGIT
import numpy as np
import random
import torchvision.utils as vutils


class Args(argparse.Namespace):
    """Hard-coded configuration handed to ``MaskGIT`` by this demo.

    Subclasses ``argparse.Namespace`` so it can presumably stand in for parsed
    command-line arguments; every field is a class attribute, so a bare
    ``Args()`` exposes all of them without any parsing.
    """

    data_folder = ""
    vqgan_folder = "pretrained_maskgit/VQGAN"
    writer_log = ""
    data = ""
    mask_value = 1024  # NOTE(review): presumably the mask-token id — confirm in runner
    seed = 1
    channel = 3
    num_workers = 0
    iter = 0  # name is read by MaskGIT; intentionally shadows the builtin here
    global_epoch = 0
    lr = 1e-4
    drop_label = 0.1
    resume = True  # NOTE(review): presumably loads weights from vit_folder — confirm in runner
    device = "cpu"
    debug = True
    test_only = False
    is_master = True
    is_multi_gpus = False
    vit_size = "base"
    vit_folder = "pretrained_maskgit/MaskGIT/MaskGIT_ImageNet_256.pth"
    img_size = 256
    # Derived from img_size (16 for 256) so the two settings cannot drift apart.
    patch_size = img_size // 16


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Seed value. Coerced with ``int()`` because UI widgets (e.g. a
            Gradio slider) may supply floats, which ``np.random.seed`` rejects.
            Values <= 0 leave every RNG and cuDNN flag untouched.
    """
    seed = int(seed)
    if seed <= 0:
        return
    torch.manual_seed(seed)
    # Safe to call without CUDA: PyTorch silently ignores it in that case.
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # The flag is `enabled` (not `enable`); assigning a misspelled name on the
    # module would just create an unused attribute and change nothing.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True

# Build the demo configuration and instantiate MaskGIT once, at import time.
# `synthesize_image` reads both `args` and `maskgit` as module-level globals.
# NOTE(review): MaskGIT presumably loads the VQGAN/ViT weights from the paths in
# Args during construction — confirm in runner.
args = Args()
maskgit = MaskGIT(args)


# Function to perform image synthesis
def synthesize_image(cls, sm_temp=1, w=3, r_temp=4.5, step=8, seed=1, nb_img=1):
    """Generate ``nb_img`` samples of class ``cls`` and tile them into one PIL image.

    Gradio ``Number``/``Slider`` inputs may arrive as floats, so the
    integer-valued arguments (``cls``, ``step``, ``seed``, ``nb_img``) are
    coerced with ``int()`` before use.

    Args:
        cls: Class label used to condition every generated sample.
        sm_temp: Forwarded to ``maskgit.sample`` as ``sm_temp``.
        w: Forwarded to ``maskgit.sample`` as ``w``.
        r_temp: Forwarded to ``maskgit.sample`` as ``r_temp``.
        step: Forwarded to ``maskgit.sample`` as ``step``.
        seed: Passed to ``set_seed``; values <= 0 skip reseeding.
        nb_img: Number of images to generate; must be >= 1.

    Returns:
        A PIL image built by ``make_grid`` (2 images per row, no padding,
        normalized) from the samples.

    Raises:
        ValueError: If ``nb_img`` is smaller than 1.
    """
    nb_img = int(nb_img)
    if nb_img < 1:
        raise ValueError(f"nb_img must be >= 1, got {nb_img}")

    set_seed(int(seed))
    with torch.no_grad():
        # One identical class label per requested image.
        labels = torch.full((nb_img,), int(cls), dtype=torch.long, device=args.device)
        # NOTE(review): element [0] is presumably the decoded image batch — confirm in runner.
        gen_sample = maskgit.sample(nb_sample=labels.size(0), labels=labels, sm_temp=sm_temp, w=w,
                                    randomize="linear", r_temp=r_temp, sched_mode="arccos",
                                    step=int(step))[0]

    # Tile the batch into a single grid and convert it to a PIL image.
    output_image = transforms.ToPILImage()(vutils.make_grid(gen_sample, nrow=2, padding=0, normalize=True))

    return output_image


# Gradio interface. Without `precision=0`, gr.Number hands values to the
# function as floats, which breaks integer uses such as `[cls] * nb_img` and
# the sampling step count; the seed slider is pinned to integer steps.
app = gr.Interface(
    fn=synthesize_image,
    inputs=[
        gr.Number(31, precision=0),                      # cls
        gr.Number(1.3),                                  # sm_temp
        gr.Number(25),                                   # w
        gr.Number(4.5),                                  # r_temp
        gr.Number(16, precision=0),                      # step
        gr.Slider(0, 1000, 60, step=1),                  # seed
        gr.Number(1, precision=0, minimum=1, maximum=4), # nb_img
    ],
    outputs=gr.Image(),
    title="Image Synthesis using MaskGIT",
)

# Launch the Gradio app. share=True also requests a public share link.
app.launch(share=True)