File size: 4,513 Bytes
49d1787
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8d7eff
49d1787
 
 
 
 
 
 
f8d7eff
 
49d1787
 
 
 
 
 
 
 
 
 
 
 
 
f8d7eff
49d1787
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import os
from transformers import BeitFeatureExtractor, BeitForImageClassification
from PIL import Image

from torchvision.utils import save_image
import torch.nn.functional as F
from torchvision import transforms

from attacker import *
from torch.nn import CrossEntropyLoss

import argparse

# Select the torch device for the model and all tensors: CUDA when torch
# reports it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def make_args(argv=None):
    """Parse the command-line options for the adversarial-noise script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and programmatic use).

    Returns:
        argparse.Namespace with ``inputs``, ``out_dir``, ``target``, ``eps``,
        ``step_size``, ``steps`` and ``test_atk``.
    """
    parser = argparse.ArgumentParser(
        description='Add adversarial noise to images to fool the anime AI-art detector')

    # A single image file or a directory of images.
    parser.add_argument('inputs', type=str)
    parser.add_argument('--out_dir', type=str, default='./output')
    # 'choices' rejects typos at parse time; an unknown value would otherwise
    # leave Attacker without a target label and fail much later.
    parser.add_argument('--target', type=str, default='auto',
                        choices=['auto', 'ai', 'human'], help='[auto, ai, human]')
    # eps / step_size are rescaled by 2/255 when the PGD attack is configured.
    parser.add_argument('--eps', type=float, default=8/8, help='Noise intensity ')
    parser.add_argument('--step_size', type=float, default=1.087313/8, help='Attack step size')
    parser.add_argument('--steps', type=int, default=20, help='Attack step count')

    # Print the classifier's prediction before and after the attack.
    parser.add_argument('--test_atk', action='store_true')

    return parser.parse_args(argv)

class Attacker:
    """Add PGD adversarial noise to images against ``saltacc/anime-ai-detect``.

    Wraps the BEiT classifier and a ``PGD`` attacker (from the local
    ``attacker`` module) and writes ``<name>_atk.png`` files to
    ``args.out_dir``.
    """

    # File extensions (lower-case) that attack() treats as images.
    _IMAGE_EXTS = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')

    def __init__(self, args, pgd_callback=None):
        """Load the detector and configure the PGD attack.

        Args:
            args: Namespace as produced by ``make_args`` (reads ``out_dir``,
                ``target``, ``eps``, ``step_size``, ``steps``, ``test_atk``).
            pgd_callback: Optional callable passed to ``PGD.set_call_back``.
                When None (the default) no callback is registered.

        Raises:
            ValueError: if ``args.target`` is not 'auto', 'ai' or 'human'.
        """
        # Validate before creating directories or loading the model; an unknown
        # target would otherwise leave self.target unset and crash in attack_().
        if args.target not in ('auto', 'ai', 'human'):
            raise ValueError(f"target must be one of 'auto', 'ai', 'human', got {args.target!r}")

        self.args = args
        os.makedirs(args.out_dir, exist_ok=True)

        print('正在加载模型...')
        self.feature_extractor = BeitFeatureExtractor.from_pretrained('saltacc/anime-ai-detect')
        self.model = BeitForImageClassification.from_pretrained('saltacc/anime-ai-detect').to(device)
        print('加载完毕')

        # Fixed attack label. In 'auto' mode the label is chosen per image in attack_().
        # NOTE(review): the 'ai' goal uses label 1 and 'human' uses label 0 —
        # presumably PGD pushes *away* from this label; confirm against
        # model.config.id2label and attacker.PGD.
        self.target = None
        if args.target == 'ai':  # attack so the image is recognized as AI
            self.target = torch.tensor([1]).to(device)
        elif args.target == 'human':
            self.target = torch.tensor([0]).to(device)

        # (x - 0.5) / 0.5 maps model inputs <-> [0, 1] pixel space; PGD gets both directions.
        # NOTE(review): assumes the feature extractor normalizes with mean=std=0.5.
        dataset_mean_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).to(device)
        dataset_std_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).to(device)
        self.pgd = PGD(self.model, img_transform=(lambda x: (x - dataset_mean_t) / dataset_std_t, lambda x: x * dataset_std_t + dataset_mean_t))
        # alpha is given as a zero-argument callable, as the original code does.
        self.pgd.set_para(eps=(args.eps * 2) / 255, alpha=lambda: (args.step_size * 2) / 255, iters=args.steps)
        self.pgd.set_loss(CrossEntropyLoss())
        if pgd_callback is not None:
            self.pgd.set_call_back(pgd_callback)

    def save_image(self, image, noise, img_name):
        """Add ``noise`` to ``image`` at full resolution and save it as a PNG.

        Only the noise is resized (bicubic) to the image size, so the source
        image keeps its original resolution and detail.

        Args:
            image: PIL image in RGB mode.
            noise: perturbation tensor; assumed (1, C, h, w) since
                F.interpolate with a 2-element size needs 4-D input.
            img_name: original file name; saved as ``<stem>_atk.png``.
        """
        W, H = image.size
        noise = F.interpolate(noise, size=(H, W), mode='bicubic')
        # Bicubic upsampling can overshoot, so keep pixel values in range.
        img_save = (transforms.ToTensor()(image) + noise).clamp(0, 1)
        stem = os.path.splitext(img_name)[0]
        save_image(img_save, os.path.join(self.args.out_dir, f'{stem}_atk.png'))

    def attack_(self, image):
        """Run the PGD attack on one PIL image.

        Returns:
            (atk_img, noise): the adversarial model input (normalized, on
            ``device``) and the perturbation in pixel space, detached on CPU.
        """
        inputs = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].to(device)

        if self.args.target == 'auto':
            # Use the model's current prediction as the attack label.
            with torch.no_grad():
                logits = self.model(inputs).logits
                target = torch.tensor([logits.argmax(-1).item()]).to(device)
        else:
            target = self.target

        if self.args.test_atk:
            self.test_image(inputs, 'before attack')

        atk_img = self.pgd.attack(inputs, target)

        to_pixels = self.pgd.img_transform[1]
        noise = to_pixels(atk_img).detach().cpu() - to_pixels(inputs).detach().cpu()

        if self.args.test_atk:
            self.test_image(atk_img, 'after attack')

        return atk_img, noise

    def attack_one(self, path):
        """Attack the image at ``path`` and save the result to ``out_dir``."""
        # Context manager closes the underlying file; convert() returns a loaded copy.
        with Image.open(path) as src:
            image = src.convert('RGB')
        atk_img, noise = self.attack_(image)
        self.save_image(image, noise, os.path.basename(path))

    def attack(self, path):
        """Attack one image file, or every supported image directly inside a directory.

        Directory mode is non-recursive and skips sub-directories, even ones
        whose names end in an image extension. A single ``path`` with a
        supported extension is always attempted, so a missing file raises.
        """
        if os.path.isdir(path):
            entries = (os.path.join(path, name) for name in sorted(os.listdir(path)))
            candidates = [p for p in entries if os.path.isfile(p)]
        else:
            candidates = [path]

        count = 0
        for img in candidates:
            if img.lower().endswith(self._IMAGE_EXTS):
                self.attack_one(img)
                count += 1
        print(f'总共攻击{count}张图像')

    @torch.no_grad()
    def test_image(self, img, pre_fix):
        """Print the predicted label and raw logits for ``img``, prefixed by ``pre_fix``."""
        outputs = self.model(img)
        logits = outputs.logits
        predicted_class_idx = logits.argmax(-1).item()
        print(pre_fix, "class:", self.model.config.id2label[predicted_class_idx], 'logits:', logits)

if __name__ == '__main__':
    args = make_args()
    # Attacker's constructor takes a PGD callback; the command line has no
    # progress UI, so hand it a no-op that accepts any arguments.
    # NOTE(review): assumes PGD ignores the callback's return value — confirm in attacker.py.
    attacker = Attacker(args, pgd_callback=lambda *cb_args, **cb_kwargs: None)
    attacker.attack(args.inputs)