Image Classification
Transformers
Safetensors
cetaceanet
biology
biodiversity
custom_code
cetacean-classifier / metric_learning.py
MalloryWittwerEPFL's picture
Upload model
6257083 verified
raw
history blame
2.36 kB
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two dimensions of the input.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` over dims ``(-2, -1)``.
    With ``p == 1`` this is average pooling; larger ``p`` weights large
    activations more heavily.

    Args:
        p: Initial value of the pooling exponent, stored as a 1-element parameter.
        eps: Lower clamp applied to the input so the fractional power is defined.
        requires_grad: Whether ``p`` is trained.
    """

    def __init__(self, p=3, eps=1e-6, requires_grad=False):
        super().__init__()
        initial_exponent = torch.ones(1) * p
        self.p = nn.Parameter(initial_exponent, requires_grad=requires_grad)
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """Pool ``x`` over its last two dimensions, dropping them from the shape."""
        positive = torch.clamp(x, min=self.eps)
        raised = positive ** self.p
        averaged = raised.mean(dim=(-2, -1))
        return averaged ** (1.0 / self.p)
# Copied and modified from
# https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/034a7d8665bb4696981698348c9370f2d4e61e35/models/ch_mdl_dolg_efficientnet.py
class ArcMarginProductSubcenter(nn.Module):
    """Sub-center ArcFace head that returns per-class cosine similarities.

    Each of the ``out_features`` classes owns ``k`` weight vectors (sub-centers),
    stored as consecutive rows of ``weight``. ``forward`` L2-normalizes the
    features and every sub-center, takes all pairwise cosines, and keeps the
    largest cosine within each class's group of ``k`` sub-centers.

    Args:
        in_features: Dimensionality of the input feature vectors.
        out_features: Number of classes.
        k: Number of sub-centers per class.
    """

    def __init__(self, in_features: int, out_features: int, k: int = 3):
        super().__init__()
        # torch.empty replaces the legacy torch.FloatTensor(rows, cols)
        # constructor; contents are uninitialized until reset_parameters().
        self.weight = nn.Parameter(
            torch.empty(out_features * k, in_features, dtype=torch.float32)
        )
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        """Re-initialize ``weight`` uniformly in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``."""
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        # nn.init.uniform_ fills in place under no_grad, avoiding direct
        # mutation through the discouraged ``.data`` attribute.
        nn.init.uniform_(self.weight, -stdv, stdv)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return the max sub-center cosine similarity for each class.

        Args:
            features: Expected ``(batch, in_features)``; ``F.normalize`` uses
                its default ``dim=1``, so the feature axis must be dim 1.

        Returns:
            ``(batch, out_features)`` tensor of cosine similarities.
        """
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        # Rows of ``weight`` are grouped as k consecutive sub-centers per class.
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine
class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """Apply a per-class additive angular (ArcFace) margin to cosine logits.

    Despite the name, ``forward`` returns scaled logits, not a scalar loss.
    For each sample, the target-class column becomes ``s * cos(theta + m_y)``
    (with a linear fallback when ``theta + m_y`` would exceed pi), and all
    other columns become ``s * cos(theta)``.

    Args:
        margins: Angular margins in radians, indexed by integer class label.
        n_classes: Number of classes (width of the one-hot target encoding).
        s: Scale applied to the output logits.
    """

    def __init__(self, margins: np.ndarray, n_classes: int, s: float = 30.0):
        super().__init__()
        self.s = s
        self.margins = margins
        self.out_dim = n_classes

    @staticmethod
    def _column(values: np.ndarray, device: torch.device) -> torch.Tensor:
        """Convert a numpy array into a float32 ``(n, 1)`` tensor on ``device``."""
        return torch.from_numpy(np.asarray(values)).float().to(device).view(-1, 1)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return margin-adjusted, scaled logits.

        Args:
            logits: ``(batch, n_classes)`` values treated as ``cos(theta)``,
                so they are expected to lie in ``[-1, 1]``.
            labels: ``(batch,)`` integer class indices.

        Returns:
            ``(batch, n_classes)`` float32 tensor on ``logits.device``.
        """
        device = logits.device
        # numpy fancy indexing needs the labels on the CPU.
        ms = self.margins[labels.cpu().numpy()]

        # Per-sample constants as (batch, 1) columns that broadcast over classes.
        # They follow logits' device instead of a hard-coded .cuda(), so this
        # works on CPU and on whichever GPU holds the logits.
        cos_m = self._column(np.cos(ms), device)
        sin_m = self._column(np.sin(ms), device)
        # cos(theta) > cos(pi - m)  <=>  theta + m < pi: region where the
        # angular margin keeps cos(theta + m) monotonic in theta.
        th = self._column(np.cos(math.pi - ms), device)
        # Linear penalty used outside that region.
        mm = self._column(np.sin(math.pi - ms) * ms, device)

        one_hot = F.one_hot(labels, self.out_dim).float()
        cosine = logits.float()
        # Rounding can push |cosine| slightly above 1; clamping keeps sqrt
        # from returning NaN, which would otherwise survive the one-hot blend
        # below (0 * NaN == NaN) even in non-target columns.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0.0, 1.0))
        phi = cosine * cos_m - sine * sin_m  # cos(theta + m)
        phi = torch.where(cosine > th, phi, cosine - mm)
        return (one_hot * phi + (1.0 - one_hot) * cosine) * self.s