Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| from robustness.datasets import ImageNet | |
| from robustness.attacker import AttackerModel | |
| from timm.models import create_model | |
| from torchvision import transforms | |
| from robustness.tools.label_maps import CLASS_DICT | |
| from src.utils import * | |
| from torchvision import transforms | |
| import gradio as gr | |
| import os | |
| from PIL import Image | |
# Target classes offered in the UI, mapped to ImageNet-1k class indices.
# NOTE(review): in the standard ImageNet-1k ordering 955 appears to be
# "jackfruit" while 975 is "lakeside" -- verify 'lake' against CLASS_DICT.
DICT_CLASSES = dict(
    lake=955,
    castle=483,
    library=624,
    dog=235,
    cat=285,
    people=842,  # "swimming trunks" class, used as a stand-in for people
)
# Side length (pixels) that every input image is resized to before the attack.
IMG_MAX_SIZE: int = 256
# timm model name, and the local checkpoint holding its robust fine-tuned weights.
ARCH: str = 'crossvit_18_dagger_408'
ARCH_PATH: str = './checkpoints/robust_crossvit_18_dagger_408.pt'
# Everything (model, target labels, input tensors) runs on this device.
DEVICE = 'cpu'
# Preprocessing for uploaded images: resize to a fixed IMG_MAX_SIZE x IMG_MAX_SIZE
# square (aspect ratio is NOT preserved), then convert the PIL image to a tensor.
CUSTOM_TRANSFORMS = transforms.Compose(
    [
        transforms.Resize([IMG_MAX_SIZE, IMG_MAX_SIZE]),
        transforms.ToTensor(),
    ]
)
def load_model(robust = True):
    """Build the CrossViT classifier wrapped in an ``AttackerModel``.

    Args:
        robust: if truthy, load the weights from ``ARCH_PATH`` and wrap the
            network in ``RobustModel``; otherwise use timm's pretrained weights.

    Returns:
        The ``AttackerModel`` on ``DEVICE``, switched to eval mode.
    """
    # AttackerModel needs a dataset object; CustomArt is built from a sample
    # image. The `with` block guarantees the image file handle is closed.
    with Image.open('samples/test.png') as test_image:
        ds = CustomArt(test_image, CUSTOM_TRANSFORMS)
        # With robust=True every parameter is replaced by the strict checkpoint
        # load below, so downloading timm's pretrained weights would be wasted.
        model = create_model(ARCH, pretrained=not robust).to(DEVICE)
        if robust:
            print("Load Robust Model")
            checkpoint = torch.load(ARCH_PATH, map_location=DEVICE)
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            model = RobustModel(model).to(DEVICE)
        model = AttackerModel(model, ds).to(DEVICE)
    return model.eval()
def gradio_fn(image_input,radio_steps,radio_class,radio_robust):
    """Run a targeted L2 attack that pushes ``image_input`` toward a class.

    Args:
        image_input: image array from the Gradio ``Image`` component
            (``None`` when nothing was uploaded).
        radio_steps: number of attack iterations (anything ``int()`` accepts).
        radio_class: key of ``DICT_CLASSES`` naming the target class.
        radio_robust: forwarded to ``load_model`` to pick robust vs. plain weights.

    Returns:
        The attacked image as an (H, W, C) numpy array.

    Raises:
        ValueError: if no input image was provided.
    """
    # Validate cheap inputs before paying for the (expensive) model load.
    if image_input is None:
        raise ValueError("Please upload an input image.")
    target = torch.tensor([int(DICT_CLASSES[radio_class])]).to(DEVICE)

    # Force 3 channels so grayscale or RGBA arrays still match the model input.
    image = Image.fromarray(image_input).convert('RGB')
    image = CUSTOM_TRANSFORMS(image).unsqueeze(0).to(DEVICE)  # add batch dim

    kwargs = {
        'constraint': '2',  # L2 attack
        'eps': 300,
        'step_size': 1,
        'iterations': int(radio_steps),
        'targeted': True,
        'do_tqdm': True,
        # NOTE(review): confirm the installed robustness Attacker accepts a
        # `device` keyword; the upstream signature does not list one.
        'device': DEVICE,
    }

    # NOTE(review): the model is rebuilt on every request; consider caching it.
    model = load_model(radio_robust)
    _, im_adv = model(image, target, make_adv=True, **kwargs)
    # Drop the batch dim and go from (C, H, W) to (H, W, C) for display.
    return im_adv.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
if __name__ == '__main__':
    # Gradio UI: controls and input image on the left, attacked image on the right.
    target_choices = list(DICT_CLASSES.keys())
    with gr.Blocks() as demo:
        gr.Markdown("# Art Adversarial Attack")
        with gr.Row():
            # Left column: attack settings, source image, and trigger button.
            with gr.Column():
                with gr.Row():
                    # How many attack iterations to run.
                    radio_steps = gr.Radio(
                        [10, 500, 1000, 1500, 2000],
                        value=500,
                        label="# Attack Steps",
                    )
                    # Which class the targeted attack steers toward.
                    radio_class = gr.Radio(
                        target_choices,
                        value=target_choices[0],
                        label="Target Class",
                    )
                    radio_robust = gr.Radio(
                        [True, False],
                        value=True,
                        label="Robust Model",
                    )
                with gr.Row():
                    image_input = gr.Image(label="Input Image")
                with gr.Row():
                    calculate_button = gr.Button("Compute")
            # Right column: the attack result.
            with gr.Column():
                target_image = gr.Image(label="Art Image")
        # Wire the button while still inside the Blocks context.
        calculate_button.click(
            fn=gradio_fn,
            inputs=[image_input, radio_steps, radio_class, radio_robust],
            outputs=target_image,
        )
    # NOTE(review): share=True is presumably ignored when hosted on HF Spaces.
    demo.launch(share=True, debug=True)