# Spaces: Runtime error
# (page-status banner captured together with the source; not part of the program)
import torch | |
import os | |
from transformers import BeitFeatureExtractor, BeitForImageClassification | |
from PIL import Image | |
from torchvision.utils import save_image | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from attacker import * | |
from torch.nn import CrossEntropyLoss | |
import argparse | |
# Device used for tensors created by this script: CUDA when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def make_args(argv=None):
    """Parse the command-line options of the attack script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with ``inputs``, ``out_dir``, ``target``, ``eps``,
        ``step_size``, ``steps`` and ``test_atk``.

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a
            ``--target`` value outside ``auto`` / ``ai`` / ``human``.
    """
    parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training')
    parser.add_argument('inputs', type=str)
    parser.add_argument('--out_dir', type=str, default='./output')
    # Restrict the value here so a typo is rejected at the CLI instead of
    # surfacing later as a missing attribute in the middle of an attack.
    parser.add_argument('--target', type=str, default='auto',
                        choices=['auto', 'ai', 'human'], help='[auto, ai, human]')
    parser.add_argument('--eps', type=float, default=8/8, help='Noise intensity ')
    parser.add_argument('--step_size', type=float, default=1.087313/8, help='Attack step size')
    parser.add_argument('--steps', type=int, default=20, help='Attack step count')
    parser.add_argument('--test_atk', action='store_true')
    return parser.parse_args(argv)
class Attacker:
    """Adds PGD adversarial noise to images to sway the anime-ai-detect classifier.

    The class loads the ``saltacc/anime-ai-detect`` BEiT model, configures a
    ``PGD`` attacker (imported from the project's ``attacker`` module) and
    writes the perturbed images to ``args.out_dir``.

    Args:
        args: Namespace as produced by ``make_args`` (reads ``out_dir``,
            ``target``, ``eps``, ``step_size``, ``steps``, ``test_atk``).
        pgd_callback: Optional callable handed to ``PGD.set_call_back``.
            When omitted a no-op callable is installed.
            NOTE(review): the callback signature is defined by ``PGD`` and is
            not visible here; the no-op therefore accepts any arguments.

    Raises:
        ValueError: if ``args.target`` is not ``'auto'``, ``'ai'`` or ``'human'``.
    """

    # Lower-case file extensions that are treated as attackable images.
    IMAGE_EXTENSIONS = ('.bmp', '.dib', '.png', '.jpg', '.jpeg',
                        '.pbm', '.pgm', '.ppm', '.tif', '.tiff')

    def __init__(self, args, pgd_callback=None):
        # Validate before the (slow) model load so a bad option fails fast.
        if args.target not in ('auto', 'ai', 'human'):
            raise ValueError(
                f"unsupported target {args.target!r}; expected 'auto', 'ai' or 'human'")

        self.args = args
        os.makedirs(args.out_dir, exist_ok=True)

        print('正在加载模型...')
        self.feature_extractor = BeitFeatureExtractor.from_pretrained('saltacc/anime-ai-detect')
        # Use the module-level `device` (CUDA with CPU fallback) instead of a
        # hard-coded .cuda(), so the script also runs on CPU-only machines.
        self.model = BeitForImageClassification.from_pretrained('saltacc/anime-ai-detect').to(device)
        print('加载完毕')

        # Label tensor handed to PGD together with the image.
        # NOTE(review): how PGD uses this label (attack towards vs. away from
        # it) is defined in the `attacker` module — confirm against id2label.
        if args.target == 'ai':  # attack so that the image is recognised as AI
            self.target = torch.tensor([1]).to(device)
        elif args.target == 'human':
            self.target = torch.tensor([0]).to(device)
        else:
            # 'auto': the label is derived per image in attack_().
            self.target = None

        dataset_mean_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).to(device)
        dataset_std_t = torch.tensor([0.5, 0.5, 0.5]).view(1, -1, 1, 1).to(device)

        # Pair of (normalise, de-normalise) transforms given to PGD.
        self.pgd = PGD(self.model, img_transform=(
            lambda x: (x - dataset_mean_t) / dataset_std_t,
            lambda x: x * dataset_std_t + dataset_mean_t))
        # eps / step_size are given in 0-255 pixel units; with mean = std = 0.5
        # the normalised value range is twice as wide as [0, 1], hence "* 2 / 255".
        self.pgd.set_para(eps=(args.eps * 2) / 255,
                          alpha=lambda: (args.step_size * 2) / 255,
                          iters=args.steps)
        self.pgd.set_loss(CrossEntropyLoss())
        if pgd_callback is None:
            pgd_callback = lambda *cb_args, **cb_kwargs: None
        self.pgd.set_call_back(pgd_callback)

    def save_image(self, image, noise, img_name):
        """Add ``noise`` to the full-resolution ``image`` and save it as ``<stem>_atk.png``.

        Args:
            image: PIL image (its ``size`` gives the output resolution).
            noise: noise tensor as returned by ``attack_``.
            img_name: file name of the source image; its extension is replaced.
        """
        # The image keeps its original resolution; only the noise is rescaled.
        W, H = image.size
        noise = F.interpolate(noise, size=(H, W), mode='bicubic')
        img_save = transforms.ToTensor()(image) + noise
        # splitext also copes with names that contain no dot at all.
        stem = os.path.splitext(img_name)[0]
        save_image(img_save, os.path.join(self.args.out_dir, f'{stem}_atk.png'))

    def attack_(self, image):
        """Run PGD on one PIL image.

        Returns:
            ``(atk_img, noise)``: the attacked tensor returned by PGD, and the
            difference between the de-normalised attacked and original inputs
            (on the CPU, detached).
        """
        inputs = self.feature_extractor(images=image, return_tensors="pt")['pixel_values'].to(device)

        if self.args.target == 'auto':
            # Use the model's current prediction for this image as the label.
            with torch.no_grad():
                logits = self.model(inputs).logits
            cls = logits.argmax(-1).item()
            target = torch.tensor([cls]).to(device)
        else:
            target = self.target

        if self.args.test_atk:
            self.test_image(inputs, 'before attack')

        atk_img = self.pgd.attack(inputs, target)

        denormalize = self.pgd.img_transform[1]
        noise = denormalize(atk_img).detach().cpu() - denormalize(inputs).detach().cpu()

        if self.args.test_atk:
            self.test_image(atk_img, 'after attack')

        return atk_img, noise

    def attack_one(self, path):
        """Attack the image file at ``path`` and save the result to the output directory."""
        image = Image.open(path).convert('RGB')
        atk_img, noise = self.attack_(image)
        self.save_image(image, noise, os.path.basename(path))

    def attack(self, path):
        """Attack a single image file or every image directly inside a directory.

        Files are selected by extension (see ``IMAGE_EXTENSIONS``); anything
        else is skipped.

        Returns:
            Number of images that were attacked.
        """
        if os.path.isdir(path):
            # Sorted for a deterministic processing order; sub-directories are
            # skipped even if their name happens to end in an image extension.
            candidates = [os.path.join(path, name) for name in sorted(os.listdir(path))]
            candidates = [p for p in candidates if not os.path.isdir(p)]
        else:
            candidates = [path]

        count = 0
        for img in candidates:
            if img.lower().endswith(self.IMAGE_EXTENSIONS):
                self.attack_one(img)
                count += 1

        print(f'总共攻击{count}张图像')
        return count

    def test_image(self, img, pre_fix):
        """Classify ``img`` and print the predicted label and logits, prefixed by ``pre_fix``."""
        # Inference only: no autograd graph is needed for a diagnostic print.
        with torch.no_grad():
            logits = self.model(img).logits
        predicted_class_idx = logits.argmax(-1).item()
        print(pre_fix, "class:", self.model.config.id2label[predicted_class_idx], 'logits:', logits)
if __name__ == '__main__':
    # Script entry point: parse the CLI, build the attacker and process the
    # file or directory given as the positional `inputs` argument.
    args = make_args()
    # Attacker accepts a PGD callback; the command-line run has nothing to
    # report to, so an explicit do-nothing callable is supplied.
    # NOTE(review): the callback signature is defined by PGD in the `attacker`
    # module, hence the catch-all parameters.
    attacker = Attacker(args, pgd_callback=lambda *cb_args, **cb_kwargs: None)
    attacker.attack(args.inputs)