|
import math |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class GeM(nn.Module): |
|
def __init__(self, p=3, eps=1e-6, requires_grad=False): |
|
super().__init__() |
|
self.p = nn.Parameter(torch.ones(1) * p, requires_grad=requires_grad) |
|
self.eps = eps |
|
|
|
def forward(self, x: torch.Tensor): |
|
return x.clamp(min=self.eps).pow(self.p).mean((-2, -1)).pow(1.0 / self.p) |
|
|
|
|
|
|
|
|
|
class ArcMarginProductSubcenter(nn.Module): |
|
def __init__(self, in_features: int, out_features: int, k: int = 3): |
|
super().__init__() |
|
self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features)) |
|
self.reset_parameters() |
|
self.k = k |
|
self.out_features = out_features |
|
|
|
def reset_parameters(self): |
|
stdv = 1.0 / math.sqrt(self.weight.size(1)) |
|
self.weight.data.uniform_(-stdv, stdv) |
|
|
|
def forward(self, features: torch.Tensor) -> torch.Tensor: |
|
cosine_all = F.linear(F.normalize(features), F.normalize(self.weight)) |
|
cosine_all = cosine_all.view(-1, self.out_features, self.k) |
|
cosine, _ = torch.max(cosine_all, dim=2) |
|
return cosine |
|
|
|
|
|
class ArcFaceLossAdaptiveMargin(nn.modules.Module): |
|
def __init__(self, margins: np.ndarray, n_classes: int, s: float = 30.0): |
|
super().__init__() |
|
self.s = s |
|
self.margins = margins |
|
self.out_dim = n_classes |
|
|
|
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: |
|
ms = self.margins[labels.cpu().numpy()] |
|
cos_m = torch.from_numpy(np.cos(ms)).float().cuda() |
|
sin_m = torch.from_numpy(np.sin(ms)).float().cuda() |
|
th = torch.from_numpy(np.cos(math.pi - ms)).float().cuda() |
|
mm = torch.from_numpy(np.sin(math.pi - ms) * ms).float().cuda() |
|
labels = F.one_hot(labels, self.out_dim).float() |
|
logits = logits.float() |
|
cosine = logits |
|
sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) |
|
phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1) |
|
phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1)) |
|
return ((labels * phi) + ((1.0 - labels) * cosine)) * self.s |
|
|