import torch
import torch.nn as nn
from robustness.datasets import ImageNet
from robustness.attacker import AttackerModel
from timm.models import create_model
from torchvision import transforms
from robustness.tools.label_maps import CLASS_DICT
from src.utils import *  # expected to provide CustomArt and RobustModel
import gradio as gr
import os
from PIL import Image

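# Target ImageNet class indices for the targeted attack.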
DICT_CLASSES = {'lake':955,
                'castle':483,
                'library':624}
IMG_MAX_SIZE = 256  # inputs are resized to IMG_MAX_SIZE x IMG_MAX_SIZE
ARCH = 'crossvit_18_dagger_408'  # timm architecture name
ARCH_PATH = './checkpoints/robust_crossvit_18_dagger_408.pt'  # robust model weights
CUSTOM_TRANSFORMS = transforms.Compose([transforms.Resize([IMG_MAX_SIZE,IMG_MAX_SIZE]),
                                        transforms.ToTensor()])
DEVICE = 'cpu'


def load_model(robust = True):
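    """Build the CrossViT backbone (optionally loading the robust checkpoint)
    and wrap it in the robustness library's AttackerModel."""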
    # The sample image is only needed to instantiate the dataset wrapper
    # that AttackerModel expects (e.g. for input normalization).
    test_image = Image.open('samples/test.png')
    ds = CustomArt(test_image,CUSTOM_TRANSFORMS)
    # Standard timm weights first; optionally overwrite them with the robust checkpoint.
    model = create_model(ARCH,pretrained = True).to(DEVICE)
    if robust:
        print("Load Robust Model")
        checkpoint = torch.load(ARCH_PATH,map_location = DEVICE)
        model.load_state_dict(checkpoint['state_dict'],strict = True)
    # Wrap the backbone so the robustness library can run attacks against it.
    model = RobustModel(model).to(DEVICE)
    model = AttackerModel(model, ds).to(DEVICE)
    model = model.eval()
    del test_image,ds
    return model


def gradio_fn(image_input,radio_steps,radio_class,radio_robust):
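    """Run a targeted PGD attack against the selected model and return the
    adversarial image that the attack produces."""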
    # Nothing to compute if no input image was provided.
    if image_input is None:
        return None
    model = load_model(radio_robust)
    kwargs = {
            'constraint':'2',             # L2-norm attack
            'eps': 300,                   # large L2 budget, so the change is clearly visible
            'step_size': 1,
            'iterations': int(radio_steps),
            'targeted': True,             # push the prediction towards the chosen class
            'do_tqdm': True,
            'device': DEVICE
            }
    # Define the target and the image
    target = torch.tensor([int(DICT_CLASSES[radio_class])]).to(DEVICE)
    image = Image.fromarray(image_input)
    image = CUSTOM_TRANSFORMS(image).to(DEVICE)
    image = torch.unsqueeze(image, dim=0)
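    # Run the attack and convert the adversarial tensor back to an HWC numpy image.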
    _, im_adv = model(image, target, make_adv=True, **kwargs)
    im_adv = im_adv.squeeze(dim = 0).permute(1,2,0).cpu().numpy()
    return im_adv


if __name__ == '__main__':
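    # Minimal Gradio UI: attack options and the input image on the left,
    # the generated "art" image on the right.
    # Hypothetical headless usage (no UI), for reference only:
    #   adv = gradio_fn(np.array(Image.open('samples/test.png')), 500, 'castle', True)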
    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Art Adversarial Attack")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # Radio Steps Adversarial attack
                    radio_steps = gr.Radio([10,500,1000,1500,2000],value = 500,label="# Attack Steps")
                    # Radio Targeted attack
                    radio_class = gr.Radio(list(DICT_CLASSES.keys()),
                                           value = list(DICT_CLASSES.keys())[0],
                                           label="Target Class")
                    # Robust vs. standard (non-robust) backbone
                    radio_robust = gr.Radio([True,False],value = True,label="Robust Model")
                # Input image
                with gr.Row():
                    image_input = gr.Image(label="Input Image")
                with gr.Row():
                    calculate_button = gr.Button("Compute")
            with gr.Column():
                target_image = gr.Image(label="Art Image")
    
        calculate_button.click(fn = gradio_fn,
                               inputs = [image_input,radio_steps,radio_class,radio_robust],
                               outputs = target_image)
    demo.launch(share = True,debug = True)  # share=True exposes a temporary public URL