Upload model
Browse files- config.json +5 -1
- config.py +28 -0
- configuration_cetacean_classifier.py +35 -0
- metric_learning.py +59 -0
- modeling_cetacean_classifier.py +66 -0
- train.py +318 -0
- utils.py +41 -0
config.json
CHANGED
@@ -2,7 +2,11 @@
|
|
2 |
"architectures": [
|
3 |
"CetaceanClassifierModelForImageClassification"
|
4 |
],
|
5 |
-
"
|
|
|
|
|
|
|
|
|
6 |
"torch_dtype": "float32",
|
7 |
"transformers_version": "4.46.0"
|
8 |
}
|
|
|
2 |
"architectures": [
|
3 |
"CetaceanClassifierModelForImageClassification"
|
4 |
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_cetacean_classifier.CetaceanClassifierConfig",
|
7 |
+
"AutoModelForImageClassification": "modeling_cetacean_classifier.CetaceanClassifierModelForImageClassification"
|
8 |
+
},
|
9 |
+
"model_type": "cetaceanet",
|
10 |
"torch_dtype": "float32",
|
11 |
"transformers_version": "4.46.0"
|
12 |
}
|
config.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
|
6 |
+
class Config(dict):
    """Dictionary whose keys can also be read as attributes.

    ``cfg.key`` is equivalent to ``cfg["key"]``. A nested ``dict`` value is
    wrapped in a new ``Config`` on access so attribute access can be chained
    (for example ``cfg.global_pool.p``).
    """

    def __getattr__(self, key):
        """Return ``self[key]``; invoked only when normal attribute lookup fails.

        Raises:
            AttributeError: if ``key`` is not present in the mapping.
        """
        try:
            val = self[key]
        except KeyError:
            # ``dict`` has no ``__getattr__`` to delegate to, so raise the
            # AttributeError directly and name the missing key in the message.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            ) from None
        if isinstance(val, dict):
            # The wrapper is a new (shallow-copied) mapping: item assignments
            # made on it are not written back to the stored value.
            return Config(val)
        return val
|
15 |
+
|
16 |
+
|
17 |
+
def load_config(path: str, default_path: Optional[str] = None) -> Config:
    """Load a YAML config file, optionally filling gaps from a defaults file.

    Args:
        path: YAML file holding the experiment configuration.
        default_path: Optional YAML file with default values. Every top-level
            key that is missing from ``path`` is copied from this file and
            reported on stdout. ``None`` skips this step.

    Returns:
        The merged configuration. Only top-level keys are merged; nested
        mappings are not combined.
    """
    with open(path, encoding="utf-8") as f:
        # NOTE(review): yaml.full_load can build more than plain data types;
        # only use it on trusted config files.
        # An empty YAML document loads as None, hence the ``or {}``.
        cfg = Config(yaml.full_load(f) or {})
    if default_path is None:
        return cfg

    # set keys not included in `path` by default
    with open(default_path, encoding="utf-8") as f:
        default_cfg = Config(yaml.full_load(f) or {})
    for key, val in default_cfg.items():
        if key not in cfg:
            print(f"used default config {key}: {val}")
            cfg[key] = val
    return cfg
|
configuration_cetacean_classifier.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
|
5 |
+
class CetaceanClassifierConfig(PretrainedConfig):
    """Configuration class for the cetacean classifier.

    Declares the ``"cetaceanet"`` model type and forwards every keyword
    argument unchanged to ``PretrainedConfig``; it defines no fields of its own.
    """

    # Model-type identifier for this configuration class.
    model_type = "cetaceanet"

    def __init__(
        self,
        # NOTE(review): the commented-out parameters below are unused; they
        # look like leftovers from a ResNet-style config template.
        # block_type="bottleneck",
        # layers: List[int] = [3, 4, 6, 3],
        # num_classes: int = 1000,
        # input_channels: int = 3,
        # cardinality: int = 1,
        # base_width: int = 64,
        # stem_width: int = 64,
        # stem_type: str = "",
        # avg_down: bool = False,
        **kwargs,
    ):
        """Pass all keyword arguments straight through to ``PretrainedConfig``."""
        # if block_type not in ["basic", "bottleneck"]:
        # raise ValueError(f"`block_type` must be 'basic' or bottleneck', got {block_type}.")
        # if stem_type not in ["", "deep", "deep-tiered"]:
        # raise ValueError(f"`stem_type` must be '', 'deep' or 'deep-tiered', got {stem_type}.")

        # self.block_type = block_type
        # self.layers = layers
        # self.num_classes = num_classes
        # self.input_channels = input_channels
        # self.cardinality = cardinality
        # self.base_width = base_width
        # self.stem_width = stem_width
        # self.stem_type = stem_type
        # self.avg_down = avg_down
        super().__init__(**kwargs)
|
metric_learning.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions of the input.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` across dims ``(-2, -1)``.
    The exponent ``p`` is stored as a one-element parameter that is trainable
    only when ``requires_grad`` is true.
    """

    def __init__(self, p=3, eps=1e-6, requires_grad=False):
        super().__init__()
        exponent = torch.ones(1) * p
        self.p = nn.Parameter(exponent, requires_grad=requires_grad)
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """Return the pooled tensor with the last two dimensions reduced away."""
        # Clamp first so every value is at least ``eps`` before the power.
        floored = torch.clamp(x, min=self.eps)
        mean_of_powers = torch.pow(floored, self.p).mean(dim=(-2, -1))
        return torch.pow(mean_of_powers, 1.0 / self.p)
|
17 |
+
|
18 |
+
|
19 |
+
# Copied and modified from
|
20 |
+
# https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/034a7d8665bb4696981698348c9370f2d4e61e35/models/ch_mdl_dolg_efficientnet.py
|
21 |
+
class ArcMarginProductSubcenter(nn.Module):
    """Cosine-similarity head with ``k`` sub-centers per class.

    Holds ``out_features * k`` weight rows. For each class, the output is the
    maximum cosine similarity between the input feature and that class's ``k``
    rows.
    """

    def __init__(self, in_features: int, out_features: int, k: int = 3):
        super().__init__()
        self.k = k
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features * k, in_features, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weights uniformly in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return per-class cosine scores, reshaped to ``(-1, out_features)``."""
        unit_features = F.normalize(features)
        unit_centers = F.normalize(self.weight)
        similarities = F.linear(unit_features, unit_centers)
        # Consecutive groups of k weight rows belong to the same class.
        per_class = similarities.view(-1, self.out_features, self.k)
        return per_class.max(dim=2).values
|
38 |
+
|
39 |
+
|
40 |
+
class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """Apply a per-class additive angular margin to cosine logits.

    Despite the name, ``forward`` returns scaled logits rather than a loss
    value. For each sample, the logit of its label class ``cos(theta)`` is
    replaced by ``cos(theta + m)``, where ``m`` is that class's margin; all
    other logits are left unchanged. Everything is then multiplied by ``s``.

    Args:
        margins: Array of margins (radians) indexed by class label.
        n_classes: Number of classes, used for the one-hot encoding.
        s: Scale applied to the returned logits.
    """

    def __init__(self, margins: np.ndarray, n_classes: int, s: float = 30.0):
        super().__init__()
        self.s = s
        self.margins = margins
        self.out_dim = n_classes

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return margin-adjusted, scaled logits with the same shape as ``logits``.

        Args:
            logits: Cosine similarities; indexed here as ``(batch, n_classes)``.
            labels: Integer class labels, one per sample.
        """
        # Build the margin terms on the logits' own device instead of a
        # hard-coded ``.cuda()`` so the module also works on CPU.
        device = logits.device
        ms = self.margins[labels.cpu().numpy()]

        def per_sample(values: np.ndarray) -> torch.Tensor:
            # One value per sample, shaped (batch, 1) to broadcast over classes.
            return torch.from_numpy(values).float().to(device).view(-1, 1)

        cos_m = per_sample(np.cos(ms))
        sin_m = per_sample(np.sin(ms))
        th = per_sample(np.cos(math.pi - ms))
        mm = per_sample(np.sin(math.pi - ms) * ms)

        labels = F.one_hot(labels, self.out_dim).float()
        cosine = logits.float()
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        phi = cosine * cos_m - sine * sin_m
        # Where cos(theta) <= cos(pi - m), use the linear fallback
        # cos(theta) - sin(pi - m) * m instead of cos(theta + m).
        phi = torch.where(cosine > th, phi, cosine - mm)
        return ((labels * phi) + ((1.0 - labels) * cosine)) * self.s
|
modeling_cetacean_classifier.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel
|
2 |
+
from PIL import Image
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from .configuration_cetacean_classifier import CetaceanClassifierConfig
|
7 |
+
from .train import SphereClassifier
|
8 |
+
|
9 |
+
|
10 |
+
# Species names, looked up by species-logit index when building predictions.
# NOTE(review): the order presumably has to match the species label encoding
# used at training time — confirm against the training data pipeline.
WHALE_CLASSES = np.array(
    [
        "beluga",
        "blue_whale",
        "bottlenose_dolphin",
        "brydes_whale",
        "commersons_dolphin",
        "common_dolphin",
        "cuviers_beaked_whale",
        "dusky_dolphin",
        "false_killer_whale",
        "fin_whale",
        "frasiers_dolphin",
        "gray_whale",
        "humpback_whale",
        "killer_whale",
        "long_finned_pilot_whale",
        "melon_headed_whale",
        "minke_whale",
        "pantropic_spotted_dolphin",
        "pygmy_killer_whale",
        "rough_toothed_dolphin",
        "sei_whale",
        "short_finned_pilot_whale",
        "southern_right_whale",
        "spinner_dolphin",
        "spotted_dolphin",
        "white_sided_dolphin",
    ]
)
|
40 |
+
|
41 |
+
|
42 |
+
class CetaceanClassifierModelForImageClassification(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a ``SphereClassifier`` checkpoint.

    ``forward`` takes a single PIL image and returns the names of the three
    species with the highest species-head scores.
    """

    config_class = CetaceanClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): the checkpoint path is relative to the current working
        # directory, not to this file — confirm this is intended.
        self.model = SphereClassifier.load_from_checkpoint("cetacean_classifier/last.ckpt")
        self.model.eval()

    def preprocess_image(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a float tensor of shape ``(1, 3, 480, 480)``.

        The image is converted to RGB first so grayscale, palette and RGBA
        inputs yield the three channels the transpose below requires.
        """
        rgb = img.convert("RGB").resize((480, 480))
        batch = np.array(rgb)[None]
        # NOTE(review): this axis order also swaps the two spatial axes
        # (relative to [0, 3, 1, 2]), and pixel values stay in 0-255 with no
        # normalisation. Kept as-is; verify against the training pipeline.
        batch = np.transpose(batch, [0, 3, 2, 1])
        return torch.Tensor(batch)

    def forward(self, img: Image.Image, labels=None):
        """Classify one image.

        Args:
            img: Input PIL image.
            labels: Unused.

        Returns:
            ``{"predictions": array}`` holding the three entries of
            ``WHALE_CLASSES`` with the highest species scores, best first.
        """
        tensor = self.preprocess_image(img)
        # Run on whichever device the wrapped model's parameters live on.
        device = next(self.model.parameters()).device
        with torch.no_grad():
            _, head_species_logits = self.model(tensor.to(device))
        # ``.cpu()`` is required before ``.numpy()`` when running on GPU.
        scores = head_species_logits.detach().cpu().numpy()
        # argsort is ascending; reverse it to rank the best class first.
        ranked = scores.argsort()[0][::-1]
        top_three_whale_preds = WHALE_CLASSES[ranked[:3]]

        return {"predictions": top_three_whale_preds}
|
train.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# import argparse
|
2 |
+
# import os
|
3 |
+
# import warnings
|
4 |
+
from typing import Dict, List, Optional, Tuple
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
# import optuna
|
8 |
+
# import pandas as pd
|
9 |
+
import timm
|
10 |
+
import torch
|
11 |
+
# import wandb
|
12 |
+
# from optuna.integration import PyTorchLightningPruningCallback
|
13 |
+
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
14 |
+
# from pytorch_lightning import loggers as pl_loggers
|
15 |
+
# from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
16 |
+
# from sklearn.model_selection import StratifiedKFold
|
17 |
+
# from torch.utils.data import ConcatDataset, DataLoader
|
18 |
+
|
19 |
+
from .config import Config, load_config
|
20 |
+
# from .dataset import WhaleDataset, load_df
|
21 |
+
from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
|
22 |
+
from .utils import WarmupCosineLambda, map_dict, topk_average_precision
|
23 |
+
|
24 |
+
|
25 |
+
# def parse():
|
26 |
+
# parser = argparse.ArgumentParser(description="Training for HappyWhale")
|
27 |
+
# parser.add_argument("--out_base_dir", default="result")
|
28 |
+
# parser.add_argument("--in_base_dir", default="input")
|
29 |
+
# parser.add_argument("--exp_name", default="tmp")
|
30 |
+
# parser.add_argument("--load_snapshot", action="store_true")
|
31 |
+
# parser.add_argument("--save_checkpoint", action="store_true")
|
32 |
+
# parser.add_argument("--wandb_logger", action="store_true")
|
33 |
+
# parser.add_argument("--config_path", default="config/debug.yaml")
|
34 |
+
# return parser.parse_args()
|
35 |
+
|
36 |
+
|
37 |
+
# class WhaleDataModule(LightningDataModule):
|
38 |
+
# def __init__(
|
39 |
+
# self,
|
40 |
+
# df: pd.DataFrame,
|
41 |
+
# cfg: Config,
|
42 |
+
# image_dir: str,
|
43 |
+
# val_bbox_name: str,
|
44 |
+
# fold: int,
|
45 |
+
# additional_dataset: WhaleDataset = None,
|
46 |
+
# ):
|
47 |
+
# super().__init__()
|
48 |
+
# self.cfg = cfg
|
49 |
+
# self.image_dir = image_dir
|
50 |
+
# self.val_bbox_name = val_bbox_name
|
51 |
+
# self.additional_dataset = additional_dataset
|
52 |
+
# if cfg.n_data != -1:
|
53 |
+
# df = df.iloc[: cfg.n_data]
|
54 |
+
# self.all_df = df
|
55 |
+
# if fold == -1:
|
56 |
+
# self.train_df = df
|
57 |
+
# else:
|
58 |
+
# skf = StratifiedKFold(n_splits=cfg.n_splits, shuffle=True, random_state=0)
|
59 |
+
# train_idx, val_idx = list(skf.split(df, df.individual_id))[fold]
|
60 |
+
# self.train_df = df.iloc[train_idx].copy()
|
61 |
+
# self.val_df = df.iloc[val_idx].copy()
|
62 |
+
# # relabel ids not included in training data as "new individual"
|
63 |
+
# new_mask = ~self.val_df.individual_id.isin(self.train_df.individual_id)
|
64 |
+
# self.val_df.individual_id.mask(new_mask, cfg.num_classes, inplace=True)
|
65 |
+
# print(f"new: {(self.val_df.individual_id == cfg.num_classes).sum()} / {len(self.val_df)}")
|
66 |
+
|
67 |
+
# def get_dataset(self, df, data_aug):
|
68 |
+
# return WhaleDataset(df, self.cfg, self.image_dir, self.val_bbox_name, data_aug)
|
69 |
+
|
70 |
+
# def train_dataloader(self):
|
71 |
+
# dataset = self.get_dataset(self.train_df, True)
|
72 |
+
# if self.additional_dataset is not None:
|
73 |
+
# dataset = ConcatDataset([dataset, self.additional_dataset])
|
74 |
+
# return DataLoader(
|
75 |
+
# dataset,
|
76 |
+
# batch_size=self.cfg.batch_size,
|
77 |
+
# shuffle=True,
|
78 |
+
# num_workers=2,
|
79 |
+
# pin_memory=True,
|
80 |
+
# drop_last=True,
|
81 |
+
# )
|
82 |
+
|
83 |
+
# def val_dataloader(self):
|
84 |
+
# if self.cfg.n_splits == -1:
|
85 |
+
# return None
|
86 |
+
# return DataLoader(
|
87 |
+
# self.get_dataset(self.val_df, False),
|
88 |
+
# batch_size=self.cfg.batch_size,
|
89 |
+
# shuffle=False,
|
90 |
+
# num_workers=2,
|
91 |
+
# pin_memory=True,
|
92 |
+
# )
|
93 |
+
|
94 |
+
# def all_dataloader(self):
|
95 |
+
# return DataLoader(
|
96 |
+
# self.get_dataset(self.all_df, False),
|
97 |
+
# batch_size=self.cfg.batch_size,
|
98 |
+
# shuffle=False,
|
99 |
+
# num_workers=2,
|
100 |
+
# pin_memory=True,
|
101 |
+
# )
|
102 |
+
|
103 |
+
|
104 |
+
class SphereClassifier(LightningModule):
    """Two-headed cosine classifier built on a timm feature backbone.

    Each backbone feature map is pooled with its own ``GeM`` layer. The pooled
    vectors are concatenated, normalised by a batch-norm or layer-norm neck,
    and fed to two ``ArcMarginProductSubcenter`` heads: one for individual ids
    and one for species.

    Args:
        cfg: Hyper-parameters; a plain dict is wrapped in ``Config``.
        id_class_nums: Per-id values used to derive the id margins.
        species_class_nums: Per-species values used to derive the species
            margins.

    The margin and loss functions are created only when both ``*_class_nums``
    arguments are given, so ``training_step`` requires them.
    """

    def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
        super().__init__()
        if not isinstance(cfg, Config):
            cfg = Config(cfg)
        self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
        # Destination of the .npz written by test_epoch_end; set by the caller.
        self.test_results_fp = None

        print(cfg.model_name)

        # NN architecture
        self.backbone = timm.create_model(
            cfg.model_name,
            in_chans=3,
            pretrained=cfg.pretrained,
            num_classes=0,
            features_only=True,
            out_indices=cfg.out_indices,
        )
        feature_dims = self.backbone.feature_info.channels()
        print(f"feature dims: {feature_dims}")
        self.global_pools = torch.nn.ModuleList(
            [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
        )
        # Plain int (not a NumPy scalar) so layer constructors get a native size.
        self.mid_features = int(np.sum(feature_dims))
        if cfg.normalization == "batchnorm":
            self.neck = torch.nn.BatchNorm1d(self.mid_features)
        elif cfg.normalization == "layernorm":
            self.neck = torch.nn.LayerNorm(self.mid_features)
        else:
            # Without a neck, get_feat would fail later with an opaque
            # attribute error; fail here with the offending value instead.
            raise ValueError(f"unsupported normalization: {cfg.normalization!r}")
        self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
        self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)
        if id_class_nums is not None and species_class_nums is not None:
            # Per-class margin: class_nums ** power * coef + cons.
            margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
            margins_species = (
                np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species
                + cfg.margin_cons_species
            )
            print("margins_id", margins_id)
            print("margins_species", margins_species)
            self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
            self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
            self.loss_fn_id = torch.nn.CrossEntropyLoss()
            self.loss_fn_species = torch.nn.CrossEntropyLoss()

    def get_feat(self, x: torch.Tensor) -> torch.Tensor:
        """Return the neck-normalised concatenation of the pooled backbone features."""
        ms = self.backbone(x)
        h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
        return self.neck(h)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(id_logits, species_logits)`` for a batch of images."""
        feat = self.get_feat(x)
        return self.head_id(feat), self.head_species(feat)

    def training_step(self, batch, batch_idx):
        """Compute the weighted id/species loss and log training metrics.

        The margin functions adjust the logits before cross-entropy. The two
        losses are mixed with ``loss_id_ratio`` and ``1 - loss_id_ratio``.
        """
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        logits_ids, logits_species = self(x)
        margin_logits_ids = self.margin_fn_id(logits_ids, ids)
        loss_ids = self.loss_fn_id(margin_logits_ids, ids)
        loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
        self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
        self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
        with torch.no_grad():
            self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
            self.log_dict(
                {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
                on_step=False,
                on_epoch=True,
            )
        return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)

    def validation_step(self, batch, batch_idx):
        """Log validation metrics from the mean of two predictions.

        The second prediction uses the input flipped along dim 3 (presumably
        image width — confirm the input layout).
        """
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        out1, out_species1 = self(x)
        out2, out_species2 = self(x.flip(3))
        output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
        self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
        self.log_dict(
            {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
            on_step=False,
            on_epoch=True,
        )

    def configure_optimizers(self):
        """Build the optimizer (separate backbone/head LRs) and warmup-cosine scheduler.

        Raises:
            ValueError: if ``hparams.optimizer`` is not ``"Adam"``, ``"AdamW"``
                or ``"RAdam"``.
        """
        backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
        head_params = (
            list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
        )
        params = [
            {"params": backbone_params, "lr": self.hparams.lr_backbone},
            {"params": head_params, "lr": self.hparams.lr_head},
        ]
        if self.hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(params)
        elif self.hparams.optimizer == "AdamW":
            optimizer = torch.optim.AdamW(params)
        elif self.hparams.optimizer == "RAdam":
            optimizer = torch.optim.RAdam(params)
        else:
            # Previously fell through and failed below on an unbound variable.
            raise ValueError(f"unsupported optimizer: {self.hparams.optimizer!r}")

        # The schedule is expressed in epochs: a warmup share of max_epochs,
        # followed by a cosine cycle over the remainder.
        warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
        cycle_steps = self.hparams.max_epochs - warmup_steps
        lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        """Return sorted id predictions, species scores and both embeddings.

        Id and species outputs are averaged over the original and dim-3-flipped
        inputs. Only the first 1000 sorted id predictions are kept.
        """
        x = batch["image"]
        feat1 = self.get_feat(x)
        out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
        feat2 = self.get_feat(x.flip(3))
        out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
        pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
        return {
            "original_index": batch["original_index"],
            "label": batch["label"],
            "label_species": batch["label_species"],
            "pred_logit": pred_logit[:, :1000],
            "pred_idx": pred_idx[:, :1000],
            "pred_species": ((out_species1 + out_species2) / 2).cpu(),
            "embed_features1": feat1.cpu(),
            "embed_features2": feat2.cpu(),
        }

    def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        """Gather all test outputs and save them on rank 0 to ``test_results_fp`` (.npz)."""
        outputs = self.all_gather(outputs)
        if self.trainer.global_rank == 0:
            epoch_results: Dict[str, np.ndarray] = {}
            for key in outputs[0].keys():
                if torch.cuda.device_count() > 1:
                    # With several GPUs the gathered tensors carry an extra
                    # leading dimension, merged here by the flatten.
                    result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
                else:
                    result = torch.cat([x[key] for x in outputs], dim=0)
                epoch_results[key] = result.detach().cpu().numpy()
            np.savez_compressed(self.test_results_fp, **epoch_results)
|
237 |
+
|
238 |
+
|
239 |
+
# def train(
|
240 |
+
# df: pd.DataFrame,
|
241 |
+
# args: argparse.Namespace,
|
242 |
+
# cfg: Config,
|
243 |
+
# fold: int,
|
244 |
+
# do_inference: bool = False,
|
245 |
+
# additional_dataset: WhaleDataset = None,
|
246 |
+
# optuna_trial: Optional[optuna.Trial] = None,
|
247 |
+
# ) -> Optional[float]:
|
248 |
+
# out_dir = f"{args.out_base_dir}/{args.exp_name}/{fold}"
|
249 |
+
# id_class_nums = df.individual_id.value_counts().sort_index().values
|
250 |
+
# species_class_nums = df.species.value_counts().sort_index().values
|
251 |
+
# model = SphereClassifier(cfg, id_class_nums=id_class_nums, species_class_nums=species_class_nums)
|
252 |
+
# data_module = WhaleDataModule(
|
253 |
+
# df, cfg, f"{args.in_base_dir}/train_images", cfg.val_bbox, fold, additional_dataset=additional_dataset
|
254 |
+
# )
|
255 |
+
# loggers = [pl_loggers.CSVLogger(out_dir)]
|
256 |
+
# if args.wandb_logger:
|
257 |
+
# loggers.append(
|
258 |
+
# pl_loggers.WandbLogger(
|
259 |
+
# project="kaggle-happywhale", group=args.exp_name, name=f"{args.exp_name}/{fold}", save_dir=out_dir
|
260 |
+
# )
|
261 |
+
# )
|
262 |
+
# callbacks = [LearningRateMonitor("epoch")]
|
263 |
+
# if optuna_trial is not None:
|
264 |
+
# callbacks.append(PyTorchLightningPruningCallback(optuna_trial, "val/mapNone"))
|
265 |
+
# if args.save_checkpoint:
|
266 |
+
# callbacks.append(ModelCheckpoint(out_dir, save_last=True, save_top_k=0))
|
267 |
+
# trainer = Trainer(
|
268 |
+
# gpus=torch.cuda.device_count(),
|
269 |
+
# max_epochs=cfg["max_epochs"],
|
270 |
+
# logger=loggers,
|
271 |
+
# callbacks=callbacks,
|
272 |
+
# checkpoint_callback=args.save_checkpoint,
|
273 |
+
# precision=16,
|
274 |
+
# sync_batchnorm=True,
|
275 |
+
# )
|
276 |
+
# ckpt_path = f"{out_dir}/last.ckpt"
|
277 |
+
# if not os.path.exists(ckpt_path) or not args.load_snapshot:
|
278 |
+
# ckpt_path = None
|
279 |
+
# trainer.fit(model, ckpt_path=ckpt_path, datamodule=data_module)
|
280 |
+
# if do_inference:
|
281 |
+
# for test_bbox in cfg.test_bboxes:
|
282 |
+
# # all train data
|
283 |
+
# model.test_results_fp = f"{out_dir}/train_{test_bbox}_results.npz"
|
284 |
+
# trainer.test(model, data_module.all_dataloader())
|
285 |
+
# # test data
|
286 |
+
# model.test_results_fp = f"{out_dir}/test_{test_bbox}_results.npz"
|
287 |
+
# df_test = load_df(args.in_base_dir, cfg, "sample_submission.csv", False)
|
288 |
+
# test_data_module = WhaleDataModule(df_test, cfg, f"{args.in_base_dir}/test_images", test_bbox, -1)
|
289 |
+
# trainer.test(model, test_data_module.all_dataloader())
|
290 |
+
|
291 |
+
# if args.wandb_logger:
|
292 |
+
# wandb.finish()
|
293 |
+
# if optuna_trial is not None:
|
294 |
+
# return trainer.callback_metrics["val/mapNone"].item()
|
295 |
+
# else:
|
296 |
+
# return None
|
297 |
+
|
298 |
+
|
299 |
+
# def main():
|
300 |
+
# args = parse()
|
301 |
+
# warnings.filterwarnings("ignore", ".*does not have many workers.*")
|
302 |
+
# cfg = load_config(args.config_path, "config/default.yaml")
|
303 |
+
# print(cfg)
|
304 |
+
# df = load_df(args.in_base_dir, cfg, "train.csv", True)
|
305 |
+
# pseudo_dataset = None
|
306 |
+
# if cfg.pseudo_label is not None:
|
307 |
+
# pseudo_df = load_df(args.in_base_dir, cfg, cfg.pseudo_label, False)
|
308 |
+
# pseudo_dataset = WhaleDataset(
|
309 |
+
# pseudo_df[pseudo_df.conf > cfg.pseudo_conf_threshold], cfg, f"{args.in_base_dir}/test_images", "", True
|
310 |
+
# )
|
311 |
+
# if cfg["n_splits"] == -1:
|
312 |
+
# train(df, args, cfg, -1, do_inference=True, additional_dataset=pseudo_dataset)
|
313 |
+
# else:
|
314 |
+
# train(df, args, cfg, 0, do_inference=True, additional_dataset=pseudo_dataset)
|
315 |
+
|
316 |
+
|
317 |
+
# if __name__ == "__main__":
|
318 |
+
# main()
|
utils.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
class WarmupCosineLambda:
    """Learning-rate multiplier: warmup from ``decay_scale`` to 1, then cosine decay.

    During the first ``warmup_steps`` epochs the multiplier rises from
    ``decay_scale`` to 1, linearly or (with ``exponential_warmup``)
    geometrically. Afterwards it follows a cosine with half-period
    ``cycle_steps`` between 1 and ``decay_scale``.
    """

    def __init__(self, warmup_steps: int, cycle_steps: int, decay_scale: float, exponential_warmup: bool = False):
        self.warmup_steps = warmup_steps
        self.cycle_steps = cycle_steps
        self.decay_scale = decay_scale
        self.exponential_warmup = exponential_warmup

    def __call__(self, epoch: int):
        """Return the multiplier for ``epoch``."""
        floor = self.decay_scale
        warming_up = epoch < self.warmup_steps

        if warming_up and self.exponential_warmup:
            # Geometric interpolation from ``floor`` (epoch 0) towards 1.
            return floor * pow(floor, -epoch / self.warmup_steps)

        if warming_up:
            ratio = epoch / self.warmup_steps
        else:
            elapsed = epoch - self.warmup_steps
            ratio = (1 + math.cos(math.pi * elapsed / self.cycle_steps)) / 2
        return floor + (1 - floor) * ratio
|
22 |
+
|
23 |
+
|
24 |
+
def topk_average_precision(output: torch.Tensor, y: torch.Tensor, k: int):
    """Score each row by the reciprocal rank of its label within the top ``k``.

    A row scores ``1 / rank`` (rank starting at 1) when label ``y`` is among
    the ``k`` highest entries of ``output``, and 0 otherwise.
    """
    rank_weights = torch.tensor([1.0 / rank for rank in range(1, k + 1)], device=output.device)
    _, top_indices = output.topk(k)
    # Compare each of the k predicted indices against the row's label.
    targets = y[:, None].expand(top_indices.shape)
    hits = top_indices == targets
    return (hits * rank_weights).sum(dim=1)
|
29 |
+
|
30 |
+
|
31 |
+
def calc_map5(output: torch.Tensor, y: torch.Tensor, threshold: Optional[float]):
    """Return the mean top-5 reciprocal-rank score of ``output`` against ``y``.

    When ``threshold`` is given, a constant column with that value is appended
    as one extra class (index ``output.shape[1]``) before scoring.
    """
    scores = output
    if threshold is not None:
        extra_column = torch.full((output.shape[0], 1), threshold, device=output.device)
        scores = torch.cat([output, extra_column], dim=1)
    return topk_average_precision(scores, y, 5).mean().detach()
|
35 |
+
|
36 |
+
|
37 |
+
def map_dict(output: torch.Tensor, y: torch.Tensor, prefix: str):
    """Build a metrics dict: top-1 accuracy plus ``calc_map5`` at several thresholds.

    Keys are ``"{prefix}/acc"`` and ``"{prefix}/map{threshold}"`` for the
    thresholds ``None, 0.3, 0.4, 0.5, 0.6, 0.7``.
    """
    metrics = {f"{prefix}/acc": topk_average_precision(output, y, 1).mean().detach()}
    metrics.update(
        {f"{prefix}/map{threshold}": calc_map5(output, y, threshold) for threshold in (None, 0.3, 0.4, 0.5, 0.6, 0.7)}
    )
    return metrics
|