Image Classification
Transformers
Safetensors
cetaceanet
biology
biodiversity
custom_code
MalloryWittwerEPFL committed (verified)
Commit 6257083
1 Parent(s): 9880624

Upload model

config.json CHANGED
@@ -2,7 +2,11 @@
   "architectures": [
     "CetaceanClassifierModelForImageClassification"
   ],
- "model_type": "efficientnet",
+ "auto_map": {
+   "AutoConfig": "configuration_cetacean_classifier.CetaceanClassifierConfig",
+   "AutoModelForImageClassification": "modeling_cetacean_classifier.CetaceanClassifierModelForImageClassification"
+ },
+ "model_type": "cetaceanet",
  "torch_dtype": "float32",
  "transformers_version": "4.46.0"
  }
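
The new `auto_map` entries register the custom configuration and model classes that ship with this repository (hence the `custom_code` tag above). A minimal loading sketch, assuming a hypothetical repo id and that executing remote code is acceptable:

from transformers import AutoConfig, AutoModelForImageClassification

# Hypothetical repo id, shown for illustration only.
repo_id = "MalloryWittwerEPFL/cetacean-classifier"

# trust_remote_code=True lets Transformers import the classes listed in auto_map
# (configuration_cetacean_classifier.py and modeling_cetacean_classifier.py).
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForImageClassification.from_pretrained(repo_id, trust_remote_code=True)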
config.py ADDED
@@ -0,0 +1,28 @@
+ from typing import Optional
+
+ import yaml
+
+
+ class Config(dict):
+     def __getattr__(self, key):
+         try:
+             val = self[key]
+         except KeyError:
+             return super().__getattr__(key)
+         if isinstance(val, dict):
+             return Config(val)
+         return val
+
+
+ def load_config(path: str, default_path: Optional[str]) -> Config:
+     with open(path) as f:
+         cfg = Config(yaml.full_load(f))
+     if default_path is not None:
+         # set keys not included in `path` by default
+         with open(default_path) as f:
+             default_cfg = Config(yaml.full_load(f))
+         for key, val in default_cfg.items():
+             if key not in cfg:
+                 print(f"used default config {key}: {val}")
+                 cfg[key] = val
+     return cfg
configuration_cetacean_classifier.py ADDED
@@ -0,0 +1,35 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+
+ class CetaceanClassifierConfig(PretrainedConfig):
+     model_type = "cetaceanet"
+
+     def __init__(
+         self,
+         # block_type="bottleneck",
+         # layers: List[int] = [3, 4, 6, 3],
+         # num_classes: int = 1000,
+         # input_channels: int = 3,
+         # cardinality: int = 1,
+         # base_width: int = 64,
+         # stem_width: int = 64,
+         # stem_type: str = "",
+         # avg_down: bool = False,
+         **kwargs,
+     ):
+         # if block_type not in ["basic", "bottleneck"]:
+         #     raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.")
+         # if stem_type not in ["", "deep", "deep-tiered"]:
+         #     raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")
+
+         # self.block_type = block_type
+         # self.layers = layers
+         # self.num_classes = num_classes
+         # self.input_channels = input_channels
+         # self.cardinality = cardinality
+         # self.base_width = base_width
+         # self.stem_width = stem_width
+         # self.stem_type = stem_type
+         # self.avg_down = avg_down
+         super().__init__(**kwargs)
metric_learning.py ADDED
@@ -0,0 +1,59 @@
+ import math
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class GeM(nn.Module):
+     def __init__(self, p=3, eps=1e-6, requires_grad=False):
+         super().__init__()
+         self.p = nn.Parameter(torch.ones(1) * p, requires_grad=requires_grad)
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor):
+         return x.clamp(min=self.eps).pow(self.p).mean((-2, -1)).pow(1.0 / self.p)
+
+
+ # Copied and modified from
+ # https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/034a7d8665bb4696981698348c9370f2d4e61e35/models/ch_mdl_dolg_efficientnet.py
+ class ArcMarginProductSubcenter(nn.Module):
+     def __init__(self, in_features: int, out_features: int, k: int = 3):
+         super().__init__()
+         self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
+         self.reset_parameters()
+         self.k = k
+         self.out_features = out_features
+
+     def reset_parameters(self):
+         stdv = 1.0 / math.sqrt(self.weight.size(1))
+         self.weight.data.uniform_(-stdv, stdv)
+
+     def forward(self, features: torch.Tensor) -> torch.Tensor:
+         cosine_all = F.linear(F.normalize(features), F.normalize(self.weight))
+         cosine_all = cosine_all.view(-1, self.out_features, self.k)
+         cosine, _ = torch.max(cosine_all, dim=2)
+         return cosine
+
+
+ class ArcFaceLossAdaptiveMargin(nn.modules.Module):
+     def __init__(self, margins: np.ndarray, n_classes: int, s: float = 30.0):
+         super().__init__()
+         self.s = s
+         self.margins = margins
+         self.out_dim = n_classes
+
+     def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+         ms = self.margins[labels.cpu().numpy()]
+         cos_m = torch.from_numpy(np.cos(ms)).float().cuda()
+         sin_m = torch.from_numpy(np.sin(ms)).float().cuda()
+         th = torch.from_numpy(np.cos(math.pi - ms)).float().cuda()
+         mm = torch.from_numpy(np.sin(math.pi - ms) * ms).float().cuda()
+         labels = F.one_hot(labels, self.out_dim).float()
+         logits = logits.float()
+         cosine = logits
+         sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
+         phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1)
+         phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1))
+         return ((labels * phi) + ((1.0 - labels) * cosine)) * self.s
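
For context, `train.py` below wires these pieces together: GeM pools the backbone feature maps, `ArcMarginProductSubcenter` turns pooled embeddings into per-class cosine similarities, and `ArcFaceLossAdaptiveMargin` applies a per-class margin before cross-entropy. A minimal sketch of that flow with made-up sizes (it assumes a CUDA device, since `ArcFaceLossAdaptiveMargin` moves its buffers to the GPU):

import numpy as np
import torch

head = ArcMarginProductSubcenter(in_features=512, out_features=100, k=3).cuda()
margin_fn = ArcFaceLossAdaptiveMargin(margins=np.full(100, 0.3), n_classes=100, s=30.0)
ce = torch.nn.CrossEntropyLoss()

features = torch.randn(8, 512).cuda()        # pooled embeddings for a batch of 8
labels = torch.randint(0, 100, (8,)).cuda()  # ground-truth class indices

cosine = head(features)                      # (8, 100): max cosine over the 3 sub-centers
margin_logits = margin_fn(cosine, labels)    # margin applied at the true class, scaled by s
loss = ce(margin_logits, labels)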
modeling_cetacean_classifier.py ADDED
@@ -0,0 +1,66 @@
+ from transformers import PreTrainedModel
+ from PIL import Image
+ import numpy as np
+ import torch
+
+ from .configuration_cetacean_classifier import CetaceanClassifierConfig
+ from .train import SphereClassifier
+
+
+ WHALE_CLASSES = np.array(
+     [
+         "beluga",
+         "blue_whale",
+         "bottlenose_dolphin",
+         "brydes_whale",
+         "commersons_dolphin",
+         "common_dolphin",
+         "cuviers_beaked_whale",
+         "dusky_dolphin",
+         "false_killer_whale",
+         "fin_whale",
+         "frasiers_dolphin",
+         "gray_whale",
+         "humpback_whale",
+         "killer_whale",
+         "long_finned_pilot_whale",
+         "melon_headed_whale",
+         "minke_whale",
+         "pantropic_spotted_dolphin",
+         "pygmy_killer_whale",
+         "rough_toothed_dolphin",
+         "sei_whale",
+         "short_finned_pilot_whale",
+         "southern_right_whale",
+         "spinner_dolphin",
+         "spotted_dolphin",
+         "white_sided_dolphin",
+     ]
+ )
+
+
+ class CetaceanClassifierModelForImageClassification(PreTrainedModel):
+     config_class = CetaceanClassifierConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = SphereClassifier.load_from_checkpoint("cetacean_classifier/last.ckpt")
+         self.model.eval()
+
+     def preprocess_image(self, img: Image) -> torch.Tensor:
+         image_resized = img.resize((480, 480))
+         image_resized = np.array(image_resized)[None]
+         image_resized = np.transpose(image_resized, [0, 3, 2, 1])
+         image_tensor = torch.Tensor(image_resized)
+         return image_tensor
+
+     def forward(self, img: Image, labels=None):
+         tensor = self.preprocess_image(img)
+         head_id_logits, head_species_logits = self.model(tensor)
+         head_species_logits = head_species_logits.detach().numpy()
+         sorted_idx = head_species_logits.argsort()[0]
+         sorted_idx = np.array(list(reversed(sorted_idx)))
+         top_three_logits = sorted_idx[:3]
+         top_three_whale_preds = WHALE_CLASSES[top_three_logits]
+
+         return {"predictions": top_three_whale_preds}
train.py ADDED
@@ -0,0 +1,318 @@
+ # import argparse
+ # import os
+ # import warnings
+ from typing import Dict, List, Optional, Tuple
+
+ import numpy as np
+ # import optuna
+ # import pandas as pd
+ import timm
+ import torch
+ # import wandb
+ # from optuna.integration import PyTorchLightningPruningCallback
+ from pytorch_lightning import LightningDataModule, LightningModule, Trainer
+ # from pytorch_lightning import loggers as pl_loggers
+ # from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+ # from sklearn.model_selection import StratifiedKFold
+ # from torch.utils.data import ConcatDataset, DataLoader
+
+ from .config import Config, load_config
+ # from .dataset import WhaleDataset, load_df
+ from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
+ from .utils import WarmupCosineLambda, map_dict, topk_average_precision
+
+
+ # def parse():
+ #     parser = argparse.ArgumentParser(description="Training for HappyWhale")
+ #     parser.add_argument("--out_base_dir", default="result")
+ #     parser.add_argument("--in_base_dir", default="input")
+ #     parser.add_argument("--exp_name", default="tmp")
+ #     parser.add_argument("--load_snapshot", action="store_true")
+ #     parser.add_argument("--save_checkpoint", action="store_true")
+ #     parser.add_argument("--wandb_logger", action="store_true")
+ #     parser.add_argument("--config_path", default="config/debug.yaml")
+ #     return parser.parse_args()
+
+
+ # class WhaleDataModule(LightningDataModule):
+ #     def __init__(
+ #         self,
+ #         df: pd.DataFrame,
+ #         cfg: Config,
+ #         image_dir: str,
+ #         val_bbox_name: str,
+ #         fold: int,
+ #         additional_dataset: WhaleDataset = None,
+ #     ):
+ #         super().__init__()
+ #         self.cfg = cfg
+ #         self.image_dir = image_dir
+ #         self.val_bbox_name = val_bbox_name
+ #         self.additional_dataset = additional_dataset
+ #         if cfg.n_data != -1:
+ #             df = df.iloc[: cfg.n_data]
+ #         self.all_df = df
+ #         if fold == -1:
+ #             self.train_df = df
+ #         else:
+ #             skf = StratifiedKFold(n_splits=cfg.n_splits, shuffle=True, random_state=0)
+ #             train_idx, val_idx = list(skf.split(df, df.individual_id))[fold]
+ #             self.train_df = df.iloc[train_idx].copy()
+ #             self.val_df = df.iloc[val_idx].copy()
+ #             # relabel ids not included in training data as "new individual"
+ #             new_mask = ~self.val_df.individual_id.isin(self.train_df.individual_id)
+ #             self.val_df.individual_id.mask(new_mask, cfg.num_classes, inplace=True)
+ #             print(f"new: {(self.val_df.individual_id == cfg.num_classes).sum()} / {len(self.val_df)}")
+
+ #     def get_dataset(self, df, data_aug):
+ #         return WhaleDataset(df, self.cfg, self.image_dir, self.val_bbox_name, data_aug)
+
+ #     def train_dataloader(self):
+ #         dataset = self.get_dataset(self.train_df, True)
+ #         if self.additional_dataset is not None:
+ #             dataset = ConcatDataset([dataset, self.additional_dataset])
+ #         return DataLoader(
+ #             dataset,
+ #             batch_size=self.cfg.batch_size,
+ #             shuffle=True,
+ #             num_workers=2,
+ #             pin_memory=True,
+ #             drop_last=True,
+ #         )
+
+ #     def val_dataloader(self):
+ #         if self.cfg.n_splits == -1:
+ #             return None
+ #         return DataLoader(
+ #             self.get_dataset(self.val_df, False),
+ #             batch_size=self.cfg.batch_size,
+ #             shuffle=False,
+ #             num_workers=2,
+ #             pin_memory=True,
+ #         )
+
+ #     def all_dataloader(self):
+ #         return DataLoader(
+ #             self.get_dataset(self.all_df, False),
+ #             batch_size=self.cfg.batch_size,
+ #             shuffle=False,
+ #             num_workers=2,
+ #             pin_memory=True,
+ #         )
+
+
+ class SphereClassifier(LightningModule):
+     def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
+         super().__init__()
+         if not isinstance(cfg, Config):
+             cfg = Config(cfg)
+         self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
+         self.test_results_fp = None
+
+         print(cfg.model_name)
+
+         # NN architecture
+         self.backbone = timm.create_model(
+             cfg.model_name,
+             in_chans=3,
+             pretrained=cfg.pretrained,
+             num_classes=0,
+             features_only=True,
+             out_indices=cfg.out_indices,
+         )
+         feature_dims = self.backbone.feature_info.channels()
+         print(f"feature dims: {feature_dims}")
+         self.global_pools = torch.nn.ModuleList(
+             [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
+         )
+         self.mid_features = np.sum(feature_dims)
+         if cfg.normalization == "batchnorm":
+             self.neck = torch.nn.BatchNorm1d(self.mid_features)
+         elif cfg.normalization == "layernorm":
+             self.neck = torch.nn.LayerNorm(self.mid_features)
+         self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
+         self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)
+         if id_class_nums is not None and species_class_nums is not None:
+             margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
+             margins_species = (
+                 np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species
+                 + cfg.margin_cons_species
+             )
+             print("margins_id", margins_id)
+             print("margins_species", margins_species)
+             self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
+             self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
+             self.loss_fn_id = torch.nn.CrossEntropyLoss()
+             self.loss_fn_species = torch.nn.CrossEntropyLoss()
+
+     def get_feat(self, x: torch.Tensor) -> torch.Tensor:
+         ms = self.backbone(x)
+         h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
+         return self.neck(h)
+
+     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         feat = self.get_feat(x)
+         return self.head_id(feat), self.head_species(feat)
+
+     def training_step(self, batch, batch_idx):
+         x, ids, species = batch["image"], batch["label"], batch["label_species"]
+         logits_ids, logits_species = self(x)
+         margin_logits_ids = self.margin_fn_id(logits_ids, ids)
+         loss_ids = self.loss_fn_id(margin_logits_ids, ids)
+         loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
+         self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
+         self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
+         with torch.no_grad():
+             self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
+             self.log_dict(
+                 {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
+                 on_step=False,
+                 on_epoch=True,
+             )
+         return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)
+
+     def validation_step(self, batch, batch_idx):
+         x, ids, species = batch["image"], batch["label"], batch["label_species"]
+         out1, out_species1 = self(x)
+         out2, out_species2 = self(x.flip(3))
+         output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
+         self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
+         self.log_dict(
+             {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
+             on_step=False,
+             on_epoch=True,
+         )
+
+     def configure_optimizers(self):
+         backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
+         head_params = (
+             list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
+         )
+         params = [
+             {"params": backbone_params, "lr": self.hparams.lr_backbone},
+             {"params": head_params, "lr": self.hparams.lr_head},
+         ]
+         if self.hparams.optimizer == "Adam":
+             optimizer = torch.optim.Adam(params)
+         elif self.hparams.optimizer == "AdamW":
+             optimizer = torch.optim.AdamW(params)
+         elif self.hparams.optimizer == "RAdam":
+             optimizer = torch.optim.RAdam(params)
+
+         warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
+         cycle_steps = self.hparams.max_epochs - warmup_steps
+         lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
+         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+         return [optimizer], [scheduler]
+
+     def test_step(self, batch, batch_idx):
+         x = batch["image"]
+         feat1 = self.get_feat(x)
+         out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
+         feat2 = self.get_feat(x.flip(3))
+         out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
+         pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
+         return {
+             "original_index": batch["original_index"],
+             "label": batch["label"],
+             "label_species": batch["label_species"],
+             "pred_logit": pred_logit[:, :1000],
+             "pred_idx": pred_idx[:, :1000],
+             "pred_species": ((out_species1 + out_species2) / 2).cpu(),
+             "embed_features1": feat1.cpu(),
+             "embed_features2": feat2.cpu(),
+         }
+
+     def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
+         outputs = self.all_gather(outputs)
+         if self.trainer.global_rank == 0:
+             epoch_results: Dict[str, np.ndarray] = {}
+             for key in outputs[0].keys():
+                 if torch.cuda.device_count() > 1:
+                     result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
+                 else:
+                     result = torch.cat([x[key] for x in outputs], dim=0)
+                 epoch_results[key] = result.detach().cpu().numpy()
+             np.savez_compressed(self.test_results_fp, **epoch_results)
+
+
+ # def train(
+ #     df: pd.DataFrame,
+ #     args: argparse.Namespace,
+ #     cfg: Config,
+ #     fold: int,
+ #     do_inference: bool = False,
+ #     additional_dataset: WhaleDataset = None,
+ #     optuna_trial: Optional[optuna.Trial] = None,
+ # ) -> Optional[float]:
+ #     out_dir = f"{args.out_base_dir}/{args.exp_name}/{fold}"
+ #     id_class_nums = df.individual_id.value_counts().sort_index().values
+ #     species_class_nums = df.species.value_counts().sort_index().values
+ #     model = SphereClassifier(cfg, id_class_nums=id_class_nums, species_class_nums=species_class_nums)
+ #     data_module = WhaleDataModule(
+ #         df, cfg, f"{args.in_base_dir}/train_images", cfg.val_bbox, fold, additional_dataset=additional_dataset
+ #     )
+ #     loggers = [pl_loggers.CSVLogger(out_dir)]
+ #     if args.wandb_logger:
+ #         loggers.append(
+ #             pl_loggers.WandbLogger(
+ #                 project="kaggle-happywhale", group=args.exp_name, name=f"{args.exp_name}/{fold}", save_dir=out_dir
+ #             )
+ #         )
+ #     callbacks = [LearningRateMonitor("epoch")]
+ #     if optuna_trial is not None:
+ #         callbacks.append(PyTorchLightningPruningCallback(optuna_trial, "val/mapNone"))
+ #     if args.save_checkpoint:
+ #         callbacks.append(ModelCheckpoint(out_dir, save_last=True, save_top_k=0))
+ #     trainer = Trainer(
+ #         gpus=torch.cuda.device_count(),
+ #         max_epochs=cfg["max_epochs"],
+ #         logger=loggers,
+ #         callbacks=callbacks,
+ #         checkpoint_callback=args.save_checkpoint,
+ #         precision=16,
+ #         sync_batchnorm=True,
+ #     )
+ #     ckpt_path = f"{out_dir}/last.ckpt"
+ #     if not os.path.exists(ckpt_path) or not args.load_snapshot:
+ #         ckpt_path = None
+ #     trainer.fit(model, ckpt_path=ckpt_path, datamodule=data_module)
+ #     if do_inference:
+ #         for test_bbox in cfg.test_bboxes:
+ #             # all train data
+ #             model.test_results_fp = f"{out_dir}/train_{test_bbox}_results.npz"
+ #             trainer.test(model, data_module.all_dataloader())
+ #             # test data
+ #             model.test_results_fp = f"{out_dir}/test_{test_bbox}_results.npz"
+ #             df_test = load_df(args.in_base_dir, cfg, "sample_submission.csv", False)
+ #             test_data_module = WhaleDataModule(df_test, cfg, f"{args.in_base_dir}/test_images", test_bbox, -1)
+ #             trainer.test(model, test_data_module.all_dataloader())
+
+ #     if args.wandb_logger:
+ #         wandb.finish()
+ #     if optuna_trial is not None:
+ #         return trainer.callback_metrics["val/mapNone"].item()
+ #     else:
+ #         return None
+
+
+ # def main():
+ #     args = parse()
+ #     warnings.filterwarnings("ignore", ".*does not have many workers.*")
+ #     cfg = load_config(args.config_path, "config/default.yaml")
+ #     print(cfg)
+ #     df = load_df(args.in_base_dir, cfg, "train.csv", True)
+ #     pseudo_dataset = None
+ #     if cfg.pseudo_label is not None:
+ #         pseudo_df = load_df(args.in_base_dir, cfg, cfg.pseudo_label, False)
+ #         pseudo_dataset = WhaleDataset(
+ #             pseudo_df[pseudo_df.conf > cfg.pseudo_conf_threshold], cfg, f"{args.in_base_dir}/test_images", "", True
+ #         )
+ #     if cfg["n_splits"] == -1:
+ #         train(df, args, cfg, -1, do_inference=True, additional_dataset=pseudo_dataset)
+ #     else:
+ #         train(df, args, cfg, 0, do_inference=True, additional_dataset=pseudo_dataset)
+
+
+ # if __name__ == "__main__":
+ #     main()
utils.py ADDED
@@ -0,0 +1,41 @@
+ import math
+ from typing import Optional
+
+ import torch
+
+
+ class WarmupCosineLambda:
+     def __init__(self, warmup_steps: int, cycle_steps: int, decay_scale: float, exponential_warmup: bool = False):
+         self.warmup_steps = warmup_steps
+         self.cycle_steps = cycle_steps
+         self.decay_scale = decay_scale
+         self.exponential_warmup = exponential_warmup
+
+     def __call__(self, epoch: int):
+         if epoch < self.warmup_steps:
+             if self.exponential_warmup:
+                 return self.decay_scale * pow(self.decay_scale, -epoch / self.warmup_steps)
+             ratio = epoch / self.warmup_steps
+         else:
+             ratio = (1 + math.cos(math.pi * (epoch - self.warmup_steps) / self.cycle_steps)) / 2
+         return self.decay_scale + (1 - self.decay_scale) * ratio
+
+
+ def topk_average_precision(output: torch.Tensor, y: torch.Tensor, k: int):
+     score_array = torch.tensor([1.0 / i for i in range(1, k + 1)], device=output.device)
+     topk = output.topk(k)[1]
+     match_mat = topk == y[:, None].expand(topk.shape)
+     return (match_mat * score_array).sum(dim=1)
+
+
+ def calc_map5(output: torch.Tensor, y: torch.Tensor, threshold: Optional[float]):
+     if threshold is not None:
+         output = torch.cat([output, torch.full((output.shape[0], 1), threshold, device=output.device)], dim=1)
+     return topk_average_precision(output, y, 5).mean().detach()
+
+
+ def map_dict(output: torch.Tensor, y: torch.Tensor, prefix: str):
+     d = {f"{prefix}/acc": topk_average_precision(output, y, 1).mean().detach()}
+     for threshold in [None, 0.3, 0.4, 0.5, 0.6, 0.7]:
+         d[f"{prefix}/map{threshold}"] = calc_map5(output, y, threshold)
+     return d
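
`WarmupCosineLambda` is meant to be passed to `torch.optim.lr_scheduler.LambdaLR`, mirroring `configure_optimizers` in `train.py`: a linear warm-up followed by a cosine cycle that decays towards `decay_scale` of the base learning rate. A short sketch with illustrative hyperparameters:

import torch

model = torch.nn.Linear(10, 2)  # stand-in module for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 5 warm-up epochs, then a 45-epoch cosine cycle decaying to 1% of the base lr.
lr_lambda = WarmupCosineLambda(warmup_steps=5, cycle_steps=45, decay_scale=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for epoch in range(50):
    # ... one training epoch would run here ...
    scheduler.step()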