MalloryWittwerEPFL committed
Commit d514464
1 Parent(s): ea195e7
Upload model
Browse files
- config.py +28 -0
- configuration_cetacean_classifier.py +12 -0
- metric_learning.py +59 -0
- model.safetensors +1 -1
- modeling_cetacean_classifier.py +92 -0
- train.py +150 -0
- utils.py +41 -0
config.py
ADDED
@@ -0,0 +1,28 @@
from typing import Optional

import yaml


class Config(dict):
    def __getattr__(self, key):
        try:
            val = self[key]
        except KeyError:
            return super().__getattr__(key)
        if isinstance(val, dict):
            return Config(val)
        return val


def load_config(path: str, default_path: Optional[str]) -> Config:
    with open(path) as f:
        cfg = Config(yaml.full_load(f))
    if default_path is not None:
        # set keys not included in `path` by default
        with open(default_path) as f:
            default_cfg = Config(yaml.full_load(f))
        for key, val in default_cfg.items():
            if key not in cfg:
                print(f"used default config {key}: {val}")
                cfg[key] = val
    return cfg
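
Usage sketch (not part of the commit): how load_config and the attribute access of Config are meant to be used. The YAML file names are hypothetical, and the import path assumes the script runs next to config.py.

# Hypothetical usage of load_config; "config.yaml" and "default.yaml"
# are assumed file names, not files from this commit.
from config import load_config

cfg = load_config("config.yaml", default_path="default.yaml")
print(cfg.model_name)     # top-level keys come back as attributes
print(cfg.global_pool.p)  # nested dicts are re-wrapped in Config, so dotted access nests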
configuration_cetacean_classifier.py
ADDED
@@ -0,0 +1,12 @@
from transformers import PretrainedConfig
from typing import List


class CetaceanClassifierConfig(PretrainedConfig):
    model_type = "cetaceanet"

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)
metric_learning.py
ADDED
@@ -0,0 +1,59 @@
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeM(nn.Module):
    # Generalized mean pooling over the spatial dims (p=1: average, p -> inf: max).
    def __init__(self, p=3, eps=1e-6, requires_grad=False):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p, requires_grad=requires_grad)
        self.eps = eps

    def forward(self, x: torch.Tensor):
        return x.clamp(min=self.eps).pow(self.p).mean((-2, -1)).pow(1.0 / self.p)


# Copied and modified from
# https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/034a7d8665bb4696981698348c9370f2d4e61e35/models/ch_mdl_dolg_efficientnet.py
class ArcMarginProductSubcenter(nn.Module):
    def __init__(self, in_features: int, out_features: int, k: int = 3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # cosine similarity to k sub-centers per class; keep the best sub-center
        cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
        cosine_all = cosine_all.view(-1, self.out_features, self.k)
        cosine, _ = torch.max(cosine_all, dim=2)
        return cosine


class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    # NOTE: the margin tensors below are created with .cuda(), so this loss assumes a GPU.
    def __init__(self, margins: np.ndarray, n_classes: int, s: float = 30.0):
        super().__init__()
        self.s = s
        self.margins = margins
        self.out_dim = n_classes

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        ms = self.margins[labels.cpu().numpy()]
        cos_m = torch.from_numpy(np.cos(ms)).float().cuda()
        sin_m = torch.from_numpy(np.sin(ms)).float().cuda()
        th = torch.from_numpy(np.cos(math.pi - ms)).float().cuda()
        mm = torch.from_numpy(np.sin(math.pi - ms) * ms).float().cuda()
        labels = F.one_hot(labels, self.out_dim).float()
        logits = logits.float()
        cosine = logits
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        # shift the target logit by the per-class angular margin: cos(theta + m)
        phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1)
        phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1))
        return ((labels * phi) + ((1.0 - labels) * cosine)) * self.s
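
Shape sanity check (illustrative tensors, not from the commit): GeM pools a (B, C, H, W) feature map down to (B, C), and the sub-center head returns one cosine logit per class.

import torch

pool = GeM(p=3)
head = ArcMarginProductSubcenter(in_features=512, out_features=26, k=3)

feature_map = torch.rand(4, 512, 12, 12)  # B, C, H, W
pooled = pool(feature_map)                # -> (4, 512)
cosine = head(pooled)                     # -> (4, 26): max over the 3 sub-centers
assert pooled.shape == (4, 512) and cosine.shape == (4, 26)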
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:648dd257e82e5d02a1e649cfb3193554c096e40e520f903efe150d746ddd70fa
 size 296028464
modeling_cetacean_classifier.py
ADDED
@@ -0,0 +1,92 @@
import albumentations as A
from transformers import PreTrainedModel
# from PIL import Image
import numpy as np
import torch
import cv2

from .train import SphereClassifier
from .configuration_cetacean_classifier import CetaceanClassifierConfig


WHALE_CLASSES = np.array(
    [
        "beluga",
        "blue_whale",
        "bottlenose_dolphin",
        "brydes_whale",
        "commersons_dolphin",
        "common_dolphin",
        "cuviers_beaked_whale",
        "dusky_dolphin",
        "false_killer_whale",
        "fin_whale",
        "frasiers_dolphin",
        "gray_whale",
        "humpback_whale",
        "killer_whale",
        "long_finned_pilot_whale",
        "melon_headed_whale",
        "minke_whale",
        "pantropic_spotted_dolphin",
        "pygmy_killer_whale",
        "rough_toothed_dolphin",
        "sei_whale",
        "short_finned_pilot_whale",
        "southern_right_whale",
        "spinner_dolphin",
        "spotted_dolphin",
        "white_sided_dolphin",
    ]
)


class CetaceanClassifierModelForImageClassification(PreTrainedModel):
    config_class = CetaceanClassifierConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = SphereClassifier(cfg=config.to_dict())

        # load_from_checkpoint("cetacean_classifier/last.ckpt")
        # self.model = SphereClassifier.load_from_checkpoint("cetacean_classifier/last.ckpt")

        self.model.eval()
        self.config = config
        self.transforms = self.make_transforms(data_aug=True)

    def make_transforms(self, data_aug: bool):
        augments = []
        if data_aug:
            aug = self.config.aug
            augments = [
                A.RandomResizedCrop(
                    self.config.image_size[0],
                    self.config.image_size[1],
                    scale=(aug["crop_scale"], 1.0),
                    ratio=(aug["crop_l"], aug["crop_r"]),
                ),
            ]
        return A.Compose(augments)

    def preprocess_image(self, img) -> torch.Tensor:
        # expects a BGR array, as returned by cv2.imread
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        image = cv2.resize(rgb, self.config.image_size, interpolation=cv2.INTER_CUBIC)
        image = self.transforms(image=image)["image"]
        return torch.Tensor(image).transpose(2, 0).unsqueeze(0)
        # image_resized = img.resize((480, 480))
        # image_resized = np.array(image_resized)[None]
        # image_resized = np.transpose(image_resized, [0, 3, 2, 1])
        # image_tensor = torch.Tensor(image_resized)
        # return image_tensor

    def forward(self, img, labels=None):
        tensor = self.preprocess_image(img)
        head_id_logits, head_species_logits = self.model(tensor)
        head_species_logits = head_species_logits.detach().numpy()
        sorted_idx = head_species_logits.argsort()[0]
        sorted_idx = np.array(list(reversed(sorted_idx)))
        top_three_logits = sorted_idx[:3]
        top_three_whale_preds = WHALE_CLASSES[top_three_logits]

        return {"predictions": top_three_whale_preds}
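
Hedged inference sketch: forward() takes a raw BGR array and returns the top-3 species names. The repo id and image path below are placeholders, not values from this commit.

import cv2
from transformers import AutoModelForImageClassification

# "user/repo-id" and "whale.jpg" are placeholders.
model = AutoModelForImageClassification.from_pretrained(
    "user/repo-id", trust_remote_code=True
)
img = cv2.imread("whale.jpg")     # BGR array, as preprocess_image expects
print(model(img)["predictions"])  # top-3 species names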
train.py
ADDED
@@ -0,0 +1,150 @@
from typing import Dict, List, Optional, Tuple

import numpy as np
import timm
import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer

from .config import Config, load_config
# from .dataset import WhaleDataset, load_df
from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
from .utils import WarmupCosineLambda, map_dict, topk_average_precision


class SphereClassifier(LightningModule):
    def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
        super().__init__()
        # import pdb; pdb.set_trace()
        if not isinstance(cfg, Config):
            cfg = Config(cfg)
        self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
        self.test_results_fp = None

        # import json
        # cfg_json = json.dumps(cfg)
        # with open("config_extracted.json", "w") as file:
        #     file.write(cfg_json)

        # NN architecture
        self.backbone = timm.create_model(
            cfg.model_name,
            in_chans=3,
            pretrained=cfg.pretrained,
            num_classes=0,
            features_only=True,
            out_indices=cfg.out_indices,
        )
        feature_dims = self.backbone.feature_info.channels()
        print(f"feature dims: {feature_dims}")
        self.global_pools = torch.nn.ModuleList(
            [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
        )
        self.mid_features = np.sum(feature_dims)
        if cfg.normalization == "batchnorm":
            self.neck = torch.nn.BatchNorm1d(self.mid_features)
        elif cfg.normalization == "layernorm":
            self.neck = torch.nn.LayerNorm(self.mid_features)
        self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
        self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)
        if id_class_nums is not None and species_class_nums is not None:
            margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
            margins_species = (
                np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species
                + cfg.margin_cons_species
            )
            print("margins_id", margins_id)
            print("margins_species", margins_species)
            self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
            self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
            self.loss_fn_id = torch.nn.CrossEntropyLoss()
            self.loss_fn_species = torch.nn.CrossEntropyLoss()

    def get_feat(self, x: torch.Tensor) -> torch.Tensor:
        ms = self.backbone(x)
        h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
        return self.neck(h)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        feat = self.get_feat(x)
        return self.head_id(feat), self.head_species(feat)

    def training_step(self, batch, batch_idx):
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        logits_ids, logits_species = self(x)
        margin_logits_ids = self.margin_fn_id(logits_ids, ids)
        loss_ids = self.loss_fn_id(margin_logits_ids, ids)
        loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
        self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
        self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
        with torch.no_grad():
            self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
            self.log_dict(
                {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
                on_step=False,
                on_epoch=True,
            )
        return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)

    def validation_step(self, batch, batch_idx):
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        out1, out_species1 = self(x)
        out2, out_species2 = self(x.flip(3))
        output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
        self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
        self.log_dict(
            {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
            on_step=False,
            on_epoch=True,
        )

    def configure_optimizers(self):
        backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
        head_params = (
            list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
        )
        params = [
            {"params": backbone_params, "lr": self.hparams.lr_backbone},
            {"params": head_params, "lr": self.hparams.lr_head},
        ]
        if self.hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(params)
        elif self.hparams.optimizer == "AdamW":
            optimizer = torch.optim.AdamW(params)
        elif self.hparams.optimizer == "RAdam":
            optimizer = torch.optim.RAdam(params)

        warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
        cycle_steps = self.hparams.max_epochs - warmup_steps
        lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        x = batch["image"]
        feat1 = self.get_feat(x)
        out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
        feat2 = self.get_feat(x.flip(3))
        out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
        pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
        return {
            "original_index": batch["original_index"],
            "label": batch["label"],
            "label_species": batch["label_species"],
            "pred_logit": pred_logit[:, :1000],
            "pred_idx": pred_idx[:, :1000],
            "pred_species": ((out_species1 + out_species2) / 2).cpu(),
            "embed_features1": feat1.cpu(),
            "embed_features2": feat2.cpu(),
        }

    def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        outputs = self.all_gather(outputs)
        if self.trainer.global_rank == 0:
            epoch_results: Dict[str, np.ndarray] = {}
            for key in outputs[0].keys():
                if torch.cuda.device_count() > 1:
                    result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
                else:
                    result = torch.cat([x[key] for x in outputs], dim=0)
                epoch_results[key] = result.detach().cpu().numpy()
            np.savez_compressed(self.test_results_fp, **epoch_results)
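
Illustrative construction sketch: every config value below is an assumed placeholder, not the configuration shipped with this commit. With id_class_nums/species_class_nums left as None, the margin-loss setup is skipped, so the module can be built for inference only.

import torch

cfg = {
    "model_name": "efficientnet_b0",  # any timm backbone that supports features_only
    "pretrained": False,
    "out_indices": (4,),
    "global_pool": {"p": 3, "train": False},
    "normalization": "batchnorm",
    "num_classes": 15587,             # individual-id count (placeholder)
    "num_species_classes": 26,        # matches len(WHALE_CLASSES)
    "n_center_id": 3,
    "n_center_species": 3,
}
model = SphereClassifier(cfg)
logits_id, logits_species = model(torch.randn(2, 3, 224, 224))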
utils.py
ADDED
@@ -0,0 +1,41 @@
import math
from typing import Optional

import torch


class WarmupCosineLambda:
    def __init__(self, warmup_steps: int, cycle_steps: int, decay_scale: float, exponential_warmup: bool = False):
        self.warmup_steps = warmup_steps
        self.cycle_steps = cycle_steps
        self.decay_scale = decay_scale
        self.exponential_warmup = exponential_warmup

    def __call__(self, epoch: int):
        if epoch < self.warmup_steps:
            if self.exponential_warmup:
                return self.decay_scale * pow(self.decay_scale, -epoch / self.warmup_steps)
            ratio = epoch / self.warmup_steps
        else:
            ratio = (1 + math.cos(math.pi * (epoch - self.warmup_steps) / self.cycle_steps)) / 2
        return self.decay_scale + (1 - self.decay_scale) * ratio


def topk_average_precision(output: torch.Tensor, y: torch.Tensor, k: int):
    score_array = torch.tensor([1.0 / i for i in range(1, k + 1)], device=output.device)
    topk = output.topk(k)[1]
    match_mat = topk == y[:, None].expand(topk.shape)
    return (match_mat * score_array).sum(dim=1)


def calc_map5(output: torch.Tensor, y: torch.Tensor, threshold: Optional[float]):
    if threshold is not None:
        output = torch.cat([output, torch.full((output.shape[0], 1), threshold, device=output.device)], dim=1)
    return topk_average_precision(output, y, 5).mean().detach()


def map_dict(output: torch.Tensor, y: torch.Tensor, prefix: str):
    d = {f"{prefix}/acc": topk_average_precision(output, y, 1).mean().detach()}
    for threshold in [None, 0.3, 0.4, 0.5, 0.6, 0.7]:
        d[f"{prefix}/map{threshold}"] = calc_map5(output, y, threshold)
    return d
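
Quick numeric check of the schedule (assumed hyperparameters): the multiplier starts at decay_scale, ramps linearly to 1.0 over the warmup, then follows a half-cosine back down to decay_scale at the end of the cycle.

lr_lambda = WarmupCosineLambda(warmup_steps=5, cycle_steps=45, decay_scale=0.01)
print(lr_lambda(0))   # 0.01  (decay_scale floor at epoch 0)
print(lr_lambda(5))   # 1.0   (warmup just finished, cosine at its peak)
print(lr_lambda(50))  # 0.01  (end of the cosine cycle)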