Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
from robustness.datasets import ImageNet | |
from robustness.attacker import AttackerModel | |
from timm.models import create_model | |
from torchvision import transforms | |
from robustness.tools.label_maps import CLASS_DICT | |
from src.utils import * | |
from torchvision import transforms | |
import gradio as gr | |
import os | |
from PIL import Image | |
# Target classes offered in the UI, mapped to the integer class index that the
# targeted attack steers the image towards (presumably ImageNet-1k indices,
# given the timm ImageNet architecture used below — confirm).
# NOTE(review): 955 for 'lake' looks suspicious; in the usual ImageNet-1k
# ordering "lakeside" is commonly listed as 975. Verify against
# CLASS_DICT['ImageNet'] before relying on it.
DICT_CLASSES = dict(
    lake=955,
    castle=483,
    library=624,
    dog=235,
    cat=285,
    # Original note says "trunks": presumably a clothing class used as a
    # stand-in for people — confirm against CLASS_DICT.
    people=842,
)
# Side length, in pixels, of the square every input image is resized to.
# Despite the name, this is an exact size (aspect ratio is not preserved),
# not an upper bound.
IMG_MAX_SIZE = 256

# timm architecture name and the local checkpoint holding its robust weights.
ARCH = 'crossvit_18_dagger_408'
ARCH_PATH = './checkpoints/robust_crossvit_18_dagger_408.pt'

# Preprocessing applied to every image: square resize, then PIL -> tensor.
CUSTOM_TRANSFORMS = transforms.Compose(
    [
        transforms.Resize([IMG_MAX_SIZE] * 2),
        transforms.ToTensor(),
    ]
)
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU so
# the demo can still start (slowly) on CPU-only hosts instead of crashing in
# model.to(DEVICE) / torch.load(map_location=DEVICE).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_model(robust=True):
    """Build the attack-ready classifier used by the demo.

    Creates the timm ``ARCH`` network. When ``robust`` is truthy, its weights
    are replaced by the checkpoint at ``ARCH_PATH`` and it is wrapped in
    ``RobustModel``. The result is then wrapped in robustness's
    ``AttackerModel`` so it can generate adversarial examples.

    Args:
        robust: If truthy, load the robust checkpoint; otherwise keep timm's
            pretrained weights.

    Returns:
        The ``AttackerModel`` moved to ``DEVICE`` and put in eval mode.
    """
    # With strict=True the robust checkpoint must supply every weight, so
    # fetching timm's pretrained weights first would be wasted work; only
    # request them for the non-robust model.
    model = create_model(ARCH, pretrained=not robust).to(DEVICE)
    if robust:
        print("Load Robust Model")
        checkpoint = torch.load(ARCH_PATH, map_location=DEVICE)
        model.load_state_dict(checkpoint['state_dict'], strict=True)
        model = RobustModel(model).to(DEVICE)

    # AttackerModel needs a dataset object, which is built here from a sample
    # image and not kept afterwards. NOTE(review): presumably only its
    # normalisation stats are consumed — confirm in CustomArt.
    # The ``with`` closes the image file handle, which was previously leaked.
    with Image.open('samples/test.png') as test_image:
        ds = CustomArt(test_image, CUSTOM_TRANSFORMS)
        model = AttackerModel(model, ds).to(DEVICE)
    return model.eval()
# Models already built by load_model, keyed by the (normalised) robust flag,
# so each click does not rebuild the network and reload weights from disk.
_MODEL_CACHE = {}


def _as_bool(value):
    """Interpret a radio value as a bool, treating the string "False" as False."""
    if isinstance(value, str):
        return value.strip().lower() == 'true'
    return bool(value)


def _get_model(robust):
    """Return the cached attacker model for ``robust``, building it on first use."""
    if robust not in _MODEL_CACHE:
        _MODEL_CACHE[robust] = load_model(robust)
    return _MODEL_CACHE[robust]


def gradio_fn(image_input, radio_steps, radio_class, radio_robust):
    """Run a targeted L2 adversarial attack on an uploaded image.

    Args:
        image_input: Image array from the ``gr.Image`` input, or ``None`` when
            nothing was uploaded (presumably an H x W x C uint8 array — confirm).
        radio_steps: Number of attack iterations; cast with ``int`` because
            it may arrive as a string.
        radio_class: Key of ``DICT_CLASSES`` naming the target class.
        radio_robust: Whether to use the robust checkpoint, given as a bool or
            as the string "True"/"False".

    Returns:
        The adversarial image as a NumPy array in (H, W, C) layout.

    Raises:
        ValueError: If no input image was provided.
        KeyError: If ``radio_class`` is not a key of ``DICT_CLASSES``.
    """
    # Fail fast with a readable message instead of an AttributeError deep
    # inside PIL when Compute is clicked before an image is uploaded.
    if image_input is None:
        raise ValueError("Please provide an input image before clicking Compute.")

    model = _get_model(_as_bool(radio_robust))
    kwargs = {
        'constraint': '2',  # L2 attack
        'eps': 300,
        'step_size': 1,
        'iterations': int(radio_steps),
        'targeted': True,
        'do_tqdm': True,
        'device': DEVICE
    }
    # Define the target and the image (batch of one).
    target = torch.tensor([int(DICT_CLASSES[radio_class])], device=DEVICE)
    # convert('RGB') keeps the model's expected channel count for RGBA or
    # grayscale inputs; it is a no-op copy for RGB images.
    image = Image.fromarray(image_input).convert('RGB')
    image = CUSTOM_TRANSFORMS(image).unsqueeze(0).to(DEVICE)
    _, im_adv = model(image, target, make_adv=True, **kwargs)
    # detach() so .numpy() works even if the returned tensor tracks gradients;
    # permute turns (C, H, W) into (H, W, C) for display.
    return im_adv.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
if __name__ == '__main__':
    # Target-class choices shown in the UI; the first one is preselected.
    class_names = list(DICT_CLASSES)

    with gr.Blocks() as demo:
        gr.Markdown("# Art Adversarial Attack")
        with gr.Row():
            # Left column: attack settings, the input image and the trigger.
            with gr.Column():
                with gr.Row():
                    # How many attack iterations to run.
                    radio_steps = gr.Radio(
                        [10, 500, 1000, 1500, 2000],
                        value=500,
                        label="# Attack Steps",
                    )
                    # Class the targeted attack pushes the image towards.
                    radio_class = gr.Radio(
                        class_names,
                        value=class_names[0],
                        label="Target Class",
                    )
                    # Whether to attack the robust checkpoint.
                    radio_robust = gr.Radio(
                        [True, False],
                        value=True,
                        label="Robust Model",
                    )
                with gr.Row():
                    image_input = gr.Image(label="Input Image")
                with gr.Row():
                    calculate_button = gr.Button("Compute")
            # Right column: the generated adversarial image.
            with gr.Column():
                target_image = gr.Image(label="Art Image")

        # Event listeners must be registered inside the Blocks context.
        calculate_button.click(
            fn=gradio_fn,
            inputs=[image_input, radio_steps, radio_class, radio_robust],
            outputs=target_image,
        )

    demo.launch(debug=True)