Image Classification
Transformers
Safetensors
cetaceanet
biology
biodiversity
custom_code
MalloryWittwerEPFL committed on
Commit
d514464
1 Parent(s): ea195e7

Upload model

Browse files
config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import yaml
4
+
5
+
6
class Config(dict):
    """Dictionary whose keys can also be read as attributes.

    ``cfg.key`` returns ``cfg["key"]``. When the stored value is a ``dict``
    it is wrapped in a new ``Config`` (a shallow copy), so nested mappings
    support chained attribute access such as ``cfg.a.b``.

    Raises:
        AttributeError: if the requested name is not a key of the mapping.
    """

    def __getattr__(self, key):
        # __getattr__ is only consulted after normal attribute lookup fails,
        # so real dict methods are never shadowed by keys.
        try:
            val = self[key]
        except KeyError:
            # dict defines no __getattr__; delegating to super() would raise a
            # misleading "'super' object has no attribute '__getattr__'".
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            ) from None
        if isinstance(val, dict):
            return Config(val)
        return val
15
+
16
+
17
def load_config(path: str, default_path: Optional[str] = None) -> Config:
    """Load a YAML config, optionally filling missing top-level keys from defaults.

    Args:
        path: YAML file with the primary configuration.
        default_path: optional YAML file with fallback values. Only top-level
            keys absent from ``path`` are copied; nested mappings are not merged.

    Returns:
        The resulting ``Config``. An empty YAML document is treated as an
        empty mapping.
    """
    # NOTE(review): yaml.full_load is less restrictive than yaml.safe_load;
    # only use it on trusted files, or switch to safe_load if the configs
    # contain plain data only.
    with open(path, encoding="utf-8") as f:
        # full_load returns None for an empty document; fall back to {}.
        cfg = Config(yaml.full_load(f) or {})
    if default_path is not None:
        # set keys not included in `path` by default
        with open(default_path, encoding="utf-8") as f:
            default_cfg = Config(yaml.full_load(f) or {})
        for key, val in default_cfg.items():
            if key not in cfg:
                print(f"used default config {key}: {val}")
                cfg[key] = val
    return cfg
configuration_cetacean_classifier.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+
5
class CetaceanClassifierConfig(PretrainedConfig):
    """Configuration class registered under the ``cetaceanet`` model type.

    It declares no fields of its own: every keyword argument is forwarded
    unchanged to ``PretrainedConfig``.
    """

    model_type = "cetaceanet"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
metric_learning.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions of the input.

    Args:
        p: exponent of the generalized mean (stored as a one-element parameter).
        eps: lower clamp applied to the input before exponentiation.
        requires_grad: whether ``p`` is trainable.
    """

    def __init__(self, p=3, eps=1e-6, requires_grad=False):
        super().__init__()
        self.eps = eps
        exponent = torch.ones(1) * p
        self.p = nn.Parameter(exponent, requires_grad=requires_grad)

    def forward(self, x: torch.Tensor):
        # Clamp first so the fractional root below never sees values < eps.
        clamped = x.clamp(min=self.eps)
        powered = clamped.pow(self.p)
        pooled = powered.mean((-2, -1))
        return pooled.pow(1.0 / self.p)
17
+
18
+
19
+ # Copied and modified from
20
+ # https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/034a7d8665bb4696981698348c9370f2d4e61e35/models/ch_mdl_dolg_efficientnet.py
21
class ArcMarginProductSubcenter(nn.Module):
    """Cosine-similarity head with ``k`` sub-centers per output class.

    The weight matrix holds ``out_features * k`` rows. ``forward`` computes
    the cosine between each L2-normalized feature and every L2-normalized
    row, then keeps the maximum over the ``k`` rows belonging to each class.
    """

    def __init__(self, in_features: int, out_features: int, k: int = 3):
        super().__init__()
        self.k = k
        self.out_features = out_features
        # Allocated uninitialized; reset_parameters() fills it right away.
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weights uniformly in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        unit_features = F.normalize(features)
        unit_weight = F.normalize(self.weight)
        similarities = F.linear(unit_features, unit_weight)
        # Consecutive groups of k rows belong to the same class.
        per_center = similarities.view(-1, self.out_features, self.k)
        return per_center.max(dim=2).values
38
+
39
+
40
class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """Apply a per-class additive angular margin to cosine logits.

    For each sample the margin is looked up from ``margins`` by its label.
    The target-class cosine ``cos(theta)`` is replaced by ``cos(theta + m)``
    (with a linear fallback once ``theta + m`` would exceed pi), the other
    classes keep their cosine, and everything is scaled by ``s``.

    Args:
        margins: array of margins (radians) indexed by class label.
        n_classes: number of classes, used for the one-hot encoding.
        s: scale applied to the returned logits.
    """

    def __init__(self, margins: np.ndarray, n_classes: int, s: float = 30.0):
        super().__init__()
        self.s = s
        self.margins = margins
        self.out_dim = n_classes

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return margin-adjusted, scaled logits with the same shape as ``logits``."""
        # Follow the device of the incoming logits instead of hard-coding
        # CUDA, so the module also works on CPU or any other accelerator.
        device = logits.device

        def per_sample(values) -> torch.Tensor:
            # One value per sample, shaped as a column to broadcast over classes.
            return torch.as_tensor(values, dtype=torch.float32, device=device).view(-1, 1)

        ms = self.margins[labels.cpu().numpy()]
        cos_m = per_sample(np.cos(ms))
        sin_m = per_sample(np.sin(ms))
        th = per_sample(np.cos(math.pi - ms))
        mm = per_sample(np.sin(math.pi - ms) * ms)

        one_hot = F.one_hot(labels, self.out_dim).float()
        cosine = logits.float()
        # Clamp guards against NaN when rounding pushes |cosine| just past 1.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        phi = cosine * cos_m - sine * sin_m
        # Where theta + m would pass pi, use the linear penalty instead.
        phi = torch.where(cosine > th, phi, cosine - mm)
        return ((one_hot * phi) + ((1.0 - one_hot) * cosine)) * self.s
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fd07eaf5ca7a871b9162257e44e9bade9312ef1378ea1d77476b35253dda14dd
3
  size 296028464
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:648dd257e82e5d02a1e649cfb3193554c096e40e520f903efe150d746ddd70fa
3
  size 296028464
modeling_cetacean_classifier.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import albumentations as A
2
+ from transformers import PreTrainedModel
3
+ # from PIL import Image
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+
8
+ from .train import SphereClassifier
9
+ from .configuration_cetacean_classifier import CetaceanClassifierConfig
10
+
11
+
12
# Species names; the array index is the position of the matching species logit.
WHALE_CLASSES = np.array(
    (
        "beluga",
        "blue_whale",
        "bottlenose_dolphin",
        "brydes_whale",
        "commersons_dolphin",
        "common_dolphin",
        "cuviers_beaked_whale",
        "dusky_dolphin",
        "false_killer_whale",
        "fin_whale",
        "frasiers_dolphin",
        "gray_whale",
        "humpback_whale",
        "killer_whale",
        "long_finned_pilot_whale",
        "melon_headed_whale",
        "minke_whale",
        "pantropic_spotted_dolphin",
        "pygmy_killer_whale",
        "rough_toothed_dolphin",
        "sei_whale",
        "short_finned_pilot_whale",
        "southern_right_whale",
        "spinner_dolphin",
        "spotted_dolphin",
        "white_sided_dolphin",
    )
)
42
+
43
+
44
class CetaceanClassifierModelForImageClassification(PreTrainedModel):
    """Hugging Face wrapper around ``SphereClassifier`` for single-image inference.

    ``forward`` takes one image array, preprocesses it, runs the wrapped
    model and returns the names of the three highest-scoring species.
    """

    config_class = CetaceanClassifierConfig

    def __init__(self, config):
        super().__init__(config)

        self.model = SphereClassifier(cfg=config.to_dict())
        self.model.eval()
        self.config = config
        # NOTE(review): data_aug=True applies a *random* resized crop at
        # inference time, so repeated calls on the same image can differ.
        # Confirm this is intended before changing it.
        self.transforms = self.make_transforms(data_aug=True)

    def make_transforms(self, data_aug: bool):
        """Build the albumentations pipeline (empty when ``data_aug`` is false)."""
        augments = []
        if data_aug:
            aug = self.config.aug
            augments = [
                A.RandomResizedCrop(
                    self.config.image_size[0],
                    self.config.image_size[1],
                    scale=(aug["crop_scale"], 1.0),
                    ratio=(aug["crop_l"], aug["crop_r"]),
                ),
            ]
        return A.Compose(augments)

    def preprocess_image(self, img) -> torch.Tensor:
        """Convert an image array into a single-item float batch tensor.

        NOTE(review): the input is converted with COLOR_BGR2RGB, so it is
        presumably a BGR array as produced by ``cv2.imread`` - confirm.
        """
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # tuple(): the size may be a list after a config round-trip.
        size = tuple(self.config.image_size)
        image = cv2.resize(rgb, size, interpolation=cv2.INTER_CUBIC)
        image = self.transforms(image=image)["image"]
        # NOTE(review): transpose(2, 0) swaps the first and last axes, which
        # also exchanges the two spatial axes (unlike permute(2, 0, 1)).
        # Kept as-is to preserve existing predictions; confirm against the
        # training pipeline.
        return torch.Tensor(image).transpose(2, 0).unsqueeze(0)

    def forward(self, img, labels=None):
        """Return ``{"predictions": <array of the top-3 species names>}``.

        ``labels`` is accepted for API compatibility and is not used.
        """
        # Run on whatever device the model lives on.
        tensor = self.preprocess_image(img).to(self.device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            _, head_species_logits = self.model(tensor)
        # .cpu() is required before .numpy() when the model is on an accelerator.
        species_scores = head_species_logits.detach().cpu().numpy()
        # argsort is ascending; reverse the first (only) row for best-first order.
        ranked = species_scores.argsort()[0][::-1]
        top_three_whale_preds = WHALE_CLASSES[ranked[:3]]

        return {"predictions": top_three_whale_preds}
train.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import timm
5
+ import torch
6
+ from pytorch_lightning import LightningDataModule, LightningModule, Trainer
7
+
8
+ from .config import Config, load_config
9
+ # from .dataset import WhaleDataset, load_df
10
+ from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
11
+ from .utils import WarmupCosineLambda, map_dict, topk_average_precision
12
+
13
+
14
class SphereClassifier(LightningModule):
    """Two-headed metric-learning classifier for individual id and species.

    Pipeline: timm backbone (multi-level features) -> one ``GeM`` pool per
    feature level -> concatenation -> normalization neck -> two
    ``ArcMarginProductSubcenter`` heads (individual id and species).

    Args:
        cfg: configuration mapping; wrapped in ``Config`` if it is a plain dict.
        id_class_nums: per-id sample counts used to derive adaptive margins.
        species_class_nums: per-species sample counts used likewise.
            The margin and loss functions are only created when both counts
            are given (i.e. for training).

    Raises:
        ValueError: if ``cfg.normalization`` is not ``"batchnorm"`` or
            ``"layernorm"``.
    """

    def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
        super().__init__()
        if not isinstance(cfg, Config):
            cfg = Config(cfg)
        self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
        self.test_results_fp = None

        # NN architecture
        self.backbone = timm.create_model(
            cfg.model_name,
            in_chans=3,
            pretrained=cfg.pretrained,
            num_classes=0,
            features_only=True,
            out_indices=cfg.out_indices,
        )
        feature_dims = self.backbone.feature_info.channels()
        print(f"feature dims: {feature_dims}")
        self.global_pools = torch.nn.ModuleList(
            [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
        )
        # Plain int rather than a NumPy scalar for the layer constructors.
        self.mid_features = int(np.sum(feature_dims))
        if cfg.normalization == "batchnorm":
            self.neck = torch.nn.BatchNorm1d(self.mid_features)
        elif cfg.normalization == "layernorm":
            self.neck = torch.nn.LayerNorm(self.mid_features)
        else:
            # Fail here rather than with a missing `neck` attribute at the
            # first forward pass.
            raise ValueError(
                f"unsupported normalization {cfg.normalization!r}; expected 'batchnorm' or 'layernorm'"
            )
        self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
        self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)
        if id_class_nums is not None and species_class_nums is not None:
            # Adaptive margins: power * coefficient + constant of the class counts.
            margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
            margins_species = (
                np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species
                + cfg.margin_cons_species
            )
            print("margins_id", margins_id)
            print("margins_species", margins_species)
            self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
            self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
            self.loss_fn_id = torch.nn.CrossEntropyLoss()
            self.loss_fn_species = torch.nn.CrossEntropyLoss()

    def get_feat(self, x: torch.Tensor) -> torch.Tensor:
        """Return the normalized embedding: pooled backbone features through the neck."""
        ms = self.backbone(x)
        h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
        return self.neck(h)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(id_cosines, species_cosines)`` for a batch of images."""
        feat = self.get_feat(x)
        return self.head_id(feat), self.head_species(feat)

    def training_step(self, batch, batch_idx):
        """Compute the weighted id/species margin loss and log training metrics."""
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        logits_ids, logits_species = self(x)
        margin_logits_ids = self.margin_fn_id(logits_ids, ids)
        loss_ids = self.loss_fn_id(margin_logits_ids, ids)
        loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
        self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
        self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
        with torch.no_grad():
            self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
            self.log_dict(
                {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
                on_step=False,
                on_epoch=True,
            )
        return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)

    def validation_step(self, batch, batch_idx):
        """Log validation metrics on the average of original and flipped predictions."""
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        out1, out_species1 = self(x)
        # Test-time augmentation: flip along dim 3.
        out2, out_species2 = self(x.flip(3))
        output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
        self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
        self.log_dict(
            {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
            on_step=False,
            on_epoch=True,
        )

    def configure_optimizers(self):
        """Build the optimizer (separate LRs for backbone and heads) and LR scheduler.

        Raises:
            ValueError: if ``hparams.optimizer`` is not ``"Adam"``, ``"AdamW"``
                or ``"RAdam"``.
        """
        backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
        head_params = (
            list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
        )
        params = [
            {"params": backbone_params, "lr": self.hparams.lr_backbone},
            {"params": head_params, "lr": self.hparams.lr_head},
        ]
        if self.hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(params)
        elif self.hparams.optimizer == "AdamW":
            optimizer = torch.optim.AdamW(params)
        elif self.hparams.optimizer == "RAdam":
            optimizer = torch.optim.RAdam(params)
        else:
            # Previously fell through and hit an UnboundLocalError below.
            raise ValueError(
                f"unsupported optimizer {self.hparams.optimizer!r}; expected 'Adam', 'AdamW' or 'RAdam'"
            )

        warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
        cycle_steps = self.hparams.max_epochs - warmup_steps
        lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        """Return flip-averaged predictions and both embeddings for one batch."""
        x = batch["image"]
        feat1 = self.get_feat(x)
        out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
        feat2 = self.get_feat(x.flip(3))
        out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
        pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
        return {
            "original_index": batch["original_index"],
            "label": batch["label"],
            "label_species": batch["label_species"],
            # Keep only the 1000 best-scoring ids per sample.
            "pred_logit": pred_logit[:, :1000],
            "pred_idx": pred_idx[:, :1000],
            "pred_species": ((out_species1 + out_species2) / 2).cpu(),
            "embed_features1": feat1.cpu(),
            "embed_features2": feat2.cpu(),
        }

    def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        """Gather all test outputs and save them to ``self.test_results_fp`` on rank 0.

        NOTE(review): ``test_epoch_end`` is a Lightning 1.x hook name;
        presumably not called by Lightning 2.x - confirm the pinned version.
        """
        outputs = self.all_gather(outputs)
        if self.trainer.global_rank == 0:
            epoch_results: Dict[str, np.ndarray] = {}
            for key in outputs[0].keys():
                if torch.cuda.device_count() > 1:
                    # Concatenate along dim 1, then merge the first two dims.
                    result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
                else:
                    result = torch.cat([x[key] for x in outputs], dim=0)
                epoch_results[key] = result.detach().cpu().numpy()
            np.savez_compressed(self.test_results_fp, **epoch_results)
utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+
6
+
7
class WarmupCosineLambda:
    """Learning-rate multiplier: warmup to 1.0, then cosine decay to ``decay_scale``.

    Calling the instance with an epoch index returns the multiplier.

    Args:
        warmup_steps: number of epochs spent warming up.
        cycle_steps: number of epochs the cosine takes to go from 1.0 down to
            ``decay_scale``.
        decay_scale: lowest multiplier (also the starting value of warmup).
        exponential_warmup: warm up geometrically instead of linearly.
    """

    def __init__(self, warmup_steps: int, cycle_steps: int, decay_scale: float, exponential_warmup: bool = False):
        self.warmup_steps, self.cycle_steps = warmup_steps, cycle_steps
        self.decay_scale = decay_scale
        self.exponential_warmup = exponential_warmup

    def __call__(self, epoch: int):
        floor = self.decay_scale
        in_warmup = epoch < self.warmup_steps
        if in_warmup and self.exponential_warmup:
            # Geometric ramp: floor at epoch 0, approaching 1.0 at warmup_steps.
            return floor * floor ** (-epoch / self.warmup_steps)
        if in_warmup:
            ratio = epoch / self.warmup_steps
        else:
            angle = math.pi * (epoch - self.warmup_steps) / self.cycle_steps
            ratio = (1 + math.cos(angle)) / 2
        return floor + (1 - floor) * ratio
22
+
23
+
24
def topk_average_precision(output: torch.Tensor, y: torch.Tensor, k: int):
    """Return, per row, ``1 / rank`` of the true label within the top ``k`` scores.

    Rows whose label is not among the ``k`` highest-scoring columns get 0.
    """
    # Weight of a hit at rank r (1-based) is 1 / r.
    rank_weights = torch.tensor([1.0 / rank for rank in range(1, k + 1)], device=output.device)
    _, top_idx = output.topk(k)
    hits = top_idx == y.unsqueeze(1).expand_as(top_idx)
    return (hits * rank_weights).sum(dim=1)
29
+
30
+
31
def calc_map5(output: torch.Tensor, y: torch.Tensor, threshold: Optional[float]):
    """Return the mean top-5 average precision of ``output`` against ``y``.

    When ``threshold`` is given, one extra column filled with that constant
    is appended to the scores before ranking.
    """
    if threshold is None:
        scores = output
    else:
        extra_column = torch.full((output.shape[0], 1), threshold, device=output.device)
        scores = torch.cat([output, extra_column], dim=1)
    return topk_average_precision(scores, y, 5).mean().detach()
35
+
36
+
37
def map_dict(output: torch.Tensor, y: torch.Tensor, prefix: str):
    """Collect top-1 accuracy and MAP@5 at several thresholds into one dict.

    Keys are ``"{prefix}/acc"`` followed by ``"{prefix}/map{threshold}"`` for
    each threshold (``None`` first, i.e. ``"{prefix}/mapNone"``).
    """
    metrics = {f"{prefix}/acc": topk_average_precision(output, y, 1).mean().detach()}
    metrics.update(
        (f"{prefix}/map{threshold}", calc_map5(output, y, threshold))
        for threshold in (None, 0.3, 0.4, 0.5, 0.6, 0.7)
    )
    return metrics