"""Gradio demo that runs a targeted L2 adversarial attack on a CrossViT model.

The user uploads an image, picks a target class, a number of attack steps and
whether to use the robust checkpoint; the attacked image is shown as output.
"""
from functools import lru_cache

# NOTE: the original import order is kept on purpose. ``from src.utils import *``
# may re-bind names imported before it, and the imports that follow it
# (including the repeated ``transforms`` import) re-bind them afterwards.
import torch
import torch.nn as nn
from robustness.datasets import ImageNet
from robustness.attacker import AttackerModel
from timm.models import create_model
from torchvision import transforms
from robustness.tools.label_maps import CLASS_DICT
from src.utils import *  # presumably provides CustomArt and RobustModel -- TODO confirm
from torchvision import transforms
import gradio as gr
import os
from PIL import Image

# Target classes offered in the UI, mapped to the class index used as attack target.
DICT_CLASSES = {'lake':955,
                'castle':483,
                'library':624}
# Side length (pixels) every input image is resized to before the attack.
IMG_MAX_SIZE = 256
ARCH = 'crossvit_18_dagger_408'
ARCH_PATH = './checkpoints/robust_crossvit_18_dagger_408.pt'
CUSTOM_TRANSFORMS = transforms.Compose([transforms.Resize([IMG_MAX_SIZE,IMG_MAX_SIZE]),
                                        transforms.ToTensor()])
DEVICE = 'cpu'


def load_model(robust = True):
    """Build the attack-ready model in eval mode on ``DEVICE``.

    Args:
        robust: when truthy, the weights stored at ``ARCH_PATH`` are loaded
            into the network; otherwise the ``pretrained=True`` weights
            returned by ``create_model`` are kept.

    Returns:
        The network wrapped in ``RobustModel`` and then ``AttackerModel``.
    """
    # The sample image is only needed to build the dataset object that
    # AttackerModel takes as its second argument.
    test_image = Image.open('samples/test.png')
    ds = CustomArt(test_image, CUSTOM_TRANSFORMS)
    model = create_model(ARCH, pretrained=True).to(DEVICE)
    if robust:
        print("Load Robust Model")
        checkpoint = torch.load(ARCH_PATH, map_location=DEVICE)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
    # NOTE(review): the wrapping is assumed to apply to both variants (the
    # collapsed source does not show the indentation) -- confirm.
    model = RobustModel(model).to(DEVICE)
    model = AttackerModel(model, ds).to(DEVICE)
    model = model.eval()
    return model


@lru_cache(maxsize=2)
def _get_model(robust):
    """Return the model for the requested variant, loading it only once."""
    return load_model(robust)


def gradio_fn(image_input, radio_steps, radio_class, radio_robust):
    """Run the targeted attack on one image and return the attacked image.

    Args:
        image_input: image as a numpy array, as delivered by ``gr.Image``.
        radio_steps: number of attack iterations (converted with ``int``).
        radio_class: key of ``DICT_CLASSES`` naming the target class.
        radio_robust: truthy to attack the robust checkpoint.

    Returns:
        The adversarial image as a numpy array in height x width x channel
        layout (``IMG_MAX_SIZE`` x ``IMG_MAX_SIZE`` after the resize).

    Raises:
        ValueError: if no input image was provided.
        KeyError: if ``radio_class`` is not a key of ``DICT_CLASSES``.
    """
    if image_input is None:
        raise ValueError("Please provide an input image before pressing Compute.")

    # Reuse the already-built model instead of rebuilding it on every click.
    model = _get_model(bool(radio_robust))
    kwargs = {
        'constraint':'2', # L2 attack
        'eps': 300,
        'step_size': 1,
        'iterations': int(radio_steps),
        'targeted': True,
        'do_tqdm': True,
        'device': DEVICE
    }
    # Define the target and the image
    target = torch.tensor([int(DICT_CLASSES[radio_class])]).to(DEVICE)
    # Force three channels so grayscale or RGBA arrays are accepted too.
    image = Image.fromarray(image_input).convert('RGB')
    image = CUSTOM_TRANSFORMS(image).to(DEVICE)
    image = torch.unsqueeze(image, dim=0)
    _, im_adv = model(image, target, make_adv=True, **kwargs)
    # Drop the batch axis and move channels last for display.
    im_adv = im_adv.squeeze(dim=0).permute(1, 2, 0).detach().cpu().numpy()
    return im_adv


if __name__ == '__main__':
    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Art Adversarial Attack")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # Radio Steps Adversarial attack
                    radio_steps = gr.Radio([10,500,1000,1500,2000],
                                           value=500,
                                           label="# Attack Steps")
                    # Radio Targeted attack
                    radio_class = gr.Radio(list(DICT_CLASSES.keys()),
                                           value=list(DICT_CLASSES.keys())[0],
                                           label="Target Class")
                    radio_robust = gr.Radio([True,False],
                                            value=True,
                                            label="Robust Model")
                # Image
                with gr.Row():
                    image_input = gr.Image(label="Input Image")
                with gr.Row():
                    calculate_button = gr.Button("Compute")
            with gr.Column():
                target_image = gr.Image(label="Art Image")

        calculate_button.click(fn=gradio_fn,
                               inputs=[image_input,radio_steps,radio_class,radio_robust],
                               outputs=target_image)
    demo.launch(share=True, debug=True)