from copy import deepcopy

import torch
from torch import nn
from torch.cuda import amp
from tqdm import tqdm

from .base import Attacker, Empty

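class PGD(Attacker):
    """Projected gradient descent (PGD) attack with an L-infinity bound of eps."""

    def __init__(self, model, img_transform=(lambda x: x, lambda x: x), use_amp=False):
        # img_transform is a pair of callables used in step():
        # img_transform[1] maps model-space images to the [0, 1] pixel range,
        # img_transform[0] maps them back to model space.
        super().__init__(model, img_transform)
        self.use_amp = use_amp
        self.call_back = None
        self.img_loader = None
        self.img_hook = None

        # With enabled=False the scaler is a pass-through.
        self.scaler = amp.GradScaler(enabled=use_amp)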

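    def set_para(self, eps=8, alpha=lambda: 8, iters=20, **kwargs):
        # alpha must be a zero-argument callable: step() calls self.alpha() on every iteration.
        super().set_para(eps=eps, alpha=alpha, iters=iters, **kwargs)

    def set_call_back(self, call_back):
        # call_back(ori_images, adv_images, labels) is invoked after each iteration of attack().
        self.call_back = call_back

    def set_img_loader(self, img_loader):
        self.img_loader = img_loader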

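    def step(self, images, labels, loss):
        """Run one PGD update and return the new adversarial images."""
        with amp.autocast(enabled=self.use_amp):
            images.requires_grad = True
            # The model output is expected to expose a .logits attribute.
            outputs = self.model(images).logits

            self.model.zero_grad()
            # Disabled DDP workaround: +outputs[2].view(-1)[0]*0+outputs[1].view(-1)[0]*0+outputs[0].view(-1)[0]*0
            cost = loss(outputs, labels)

        # The scaler multiplies the loss by a positive factor, so the gradient
        # sign used below is unchanged (barring overflow).
        self.scaler.scale(cost).backward()

        # Ascend the loss along the sign of the input gradient.
        adv_images = (images + self.alpha() * images.grad.sign()).detach_()
        # Project the perturbation back into [-eps, eps] around the original images.
        eta = torch.clamp(adv_images - self.ori_images, min=-self.eps, max=self.eps)
        # Clamp to [0, 1] in pixel space (img_transform[1]), then map back to model space (img_transform[0]).
        images = self.img_transform[0](torch.clamp(self.img_transform[1](self.ori_images + eta), min=0, max=1).detach_())

        return images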

    def set_data(self, images, labels):
        self.ori_images = deepcopy(images)
        self.images = images
        self.labels = labels

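    def __iter__(self):
        # Reset the step counter; call set_data() first to supply images and labels.
        self.atk_step = 0
        return self

    def __next__(self):
        self.atk_step += 1
        if self.atk_step > self.iters:
            raise StopIteration

        # Under DistributedDataParallel, skip gradient synchronization during the attack step.
        # Empty (imported from .base) is the fallback context manager for non-DDP models.
        with self.model.no_sync() if isinstance(self.model, nn.parallel.DistributedDataParallel) else Empty():
            self.model.eval()

            # self.forward is not defined in this file; it is presumably supplied by the Attacker base class.
            self.images = self.forward(self, self.images, self.labels)

            self.model.zero_grad()
            self.model.train()

        return self.ori_images, self.images.detach(), self.labels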

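    def attack(self, images, labels):
        """Run all self.iters PGD iterations and return the final adversarial images."""
        #images = deepcopy(images)
        self.ori_images = deepcopy(images)

        for _ in tqdm(range(self.iters)):
            self.model.eval()

            # self.forward is not defined in this file; it is presumably supplied by the Attacker base class.
            images = self.forward(self, images, labels)

            self.model.zero_grad()
            self.model.train()
            # Report the intermediate result after every iteration.
            if self.call_back:
                self.call_back(self.ori_images, images.detach(), labels)

            # Optional hook that may replace the adversarial images before the next iteration.
            if self.img_hook is not None:
                images = self.img_hook(self.ori_images, images.detach())

        return images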