import torch


class FGM(object):
    """
    FGM (Fast Gradient Method) adversarial training.
    Reference: "Adversarial Training Methods for Semi-Supervised Text Classification".
    """

    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1e-6, emb_name="embedding"):
        # Back up each embedding parameter, then add a perturbation of size
        # epsilon along the (L2-normalized) gradient direction.
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name="embedding"):
        # Restore the embedding parameters from the backup taken in attack().
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}


class PGD(object):
    """
    PGD (Projected Gradient Descent) adversarial training.
    Reference: "Towards Deep Learning Models Resistant to Adversarial Attacks".
    """

    def __init__(self, model):
        self.model = model
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, epsilon=1., alpha=0.3, emb_name="embedding", is_first_attack=False):
        # One ascent step of size alpha along the gradient, followed by a
        # projection of the accumulated perturbation back into the epsilon-ball
        # around the original embeddings. The backup is taken only on the first
        # attack step of each batch.
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, epsilon)

    def restore(self, emb_name="embedding"):
        # Restore the embedding parameters from the backup taken on the first attack.
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        # Clip the perturbation so it stays within an L2 ball of radius epsilon
        # around the backed-up embedding values.
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        # Save the gradients computed on the clean batch so they can be restored
        # before the final adversarial backward pass.
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        # Restore the gradients saved by backup_grad().
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.grad = self.grad_backup[name]
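

# --------------------------------------------------------------------------
# Usage sketches (illustrative, not part of the original classes). The names
# model, inputs, labels, optimizer, and loss_fn are assumed placeholders for
# an actual training setup, and emb_name should match the embedding parameter
# names of the model being trained.
# --------------------------------------------------------------------------


def train_step_fgm(model, inputs, labels, optimizer, loss_fn,
                   epsilon=1e-6, emb_name="embedding"):
    """One training step with FGM adversarial training (sketch)."""
    fgm = FGM(model)
    loss = loss_fn(model(inputs), labels)
    loss.backward()                                   # gradients on the clean batch
    fgm.attack(epsilon=epsilon, emb_name=emb_name)    # perturb embeddings in place
    loss_adv = loss_fn(model(inputs), labels)
    loss_adv.backward()                               # accumulate adversarial gradients
    fgm.restore(emb_name=emb_name)                    # undo the perturbation
    optimizer.step()
    optimizer.zero_grad()


def train_step_pgd(model, inputs, labels, optimizer, loss_fn,
                   k=3, epsilon=1., alpha=0.3, emb_name="embedding"):
    """One training step with K-step PGD adversarial training (sketch)."""
    pgd = PGD(model)
    loss = loss_fn(model(inputs), labels)
    loss.backward()                                   # gradients on the clean batch
    pgd.backup_grad()                                 # save clean gradients
    for t in range(k):
        pgd.attack(epsilon=epsilon, alpha=alpha, emb_name=emb_name,
                   is_first_attack=(t == 0))
        if t != k - 1:
            model.zero_grad()                         # intermediate steps: fresh gradients only
        else:
            pgd.restore_grad()                        # last step: start from the clean gradients
        loss_adv = loss_fn(model(inputs), labels)
        loss_adv.backward()                           # adversarial gradients (added to clean ones on the last step)
    pgd.restore(emb_name=emb_name)                    # undo the accumulated perturbation
    optimizer.step()
    optimizer.zero_grad()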