import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class GeM(nn.Module): def __init__(self, p=3, eps=1e-6, requires_grad=False): super().__init__() self.p = nn.Parameter(torch.ones(1) * p, requires_grad=requires_grad) self.eps = eps def forward(self, x: torch.Tensor): return x.clamp(min=self.eps).pow(self.p).mean((-2, -1)).pow(1.0 / self.p) # Copied and modified from # https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/034a7d8665bb4696981698348c9370f2d4e61e35/models/ch_mdl_dolg_efficientnet.py class ArcMarginProductSubcenter(nn.Module): def __init__(self, in_features: int, out_features: int, k: int = 3): super().__init__() self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features)) self.reset_parameters() self.k = k self.out_features = out_features def reset_parameters(self): stdv = 1.0 / math.sqrt(self.weight.size(1)) self.weight.data.uniform_(-stdv, stdv) def forward(self, features: torch.Tensor) -> torch.Tensor: cosine_all = F.linear(F.normalize(features), F.normalize(self.weight)) cosine_all = cosine_all.view(-1, self.out_features, self.k) cosine, _ = torch.max(cosine_all, dim=2) return cosine class ArcFaceLossAdaptiveMargin(nn.modules.Module): def __init__(self, margins: np.ndarray, n_classes: int, s: float = 30.0): super().__init__() self.s = s self.margins = margins self.out_dim = n_classes def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: ms = self.margins[labels.cpu().numpy()] cos_m = torch.from_numpy(np.cos(ms)).float().cuda() sin_m = torch.from_numpy(np.sin(ms)).float().cuda() th = torch.from_numpy(np.cos(math.pi - ms)).float().cuda() mm = torch.from_numpy(np.sin(math.pi - ms) * ms).float().cuda() labels = F.one_hot(labels, self.out_dim).float() logits = logits.float() cosine = logits sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1) phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1)) return ((labels * phi) + ((1.0 - labels) * cosine)) * self.s