Image Classification
Transformers
Safetensors
cetaceanet
biology
biodiversity
custom_code
File size: 7,273 Bytes
d514464
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from typing import Dict, List, Optional, Tuple

import numpy as np
import timm
import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer

from .config import Config, load_config
# from .dataset import WhaleDataset, load_df
from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
from .utils import WarmupCosineLambda, map_dict, topk_average_precision


class SphereClassifier(LightningModule):
    """Lightning module with two margin-product heads on a shared timm backbone.

    The backbone is created with ``features_only=True`` and returns one feature
    map per entry of ``cfg.out_indices``. Each map is pooled by its own ``GeM``
    module, the pooled vectors are concatenated, normalised by ``self.neck``
    and fed to two ``ArcMarginProductSubcenter`` heads: one for individual ids
    (``head_id``) and one for species (``head_species``).

    Attributes:
        test_results_fp: Destination passed to ``np.savez_compressed`` in
            ``test_epoch_end``. Initialised to ``None``; it has to be assigned
            by the caller before running the test loop.
    """

    def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
        """Build backbone, pooling, neck, heads and (optionally) the losses.

        Args:
            cfg: Configuration as a ``Config`` or a plain mapping (wrapped into
                ``Config``). It is also stored as the module hyper-parameters.
            id_class_nums: Per-id values used to derive the adaptive margins
                via ``np.power`` (presumably per-class sample counts; verify
                against the caller).
            species_class_nums: Same as ``id_class_nums`` for species.

        The margin functions and losses are only created when both
        ``id_class_nums`` and ``species_class_nums`` are given, so a module
        built without them cannot run ``training_step``.

        Raises:
            ValueError: If ``cfg.normalization`` is neither ``"batchnorm"``
                nor ``"layernorm"``.
        """
        super().__init__()
        if not isinstance(cfg, Config):
            cfg = Config(cfg)
        self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
        self.test_results_fp = None

        # NN architecture
        self.backbone = timm.create_model(
            cfg.model_name,
            in_chans=3,
            pretrained=cfg.pretrained,
            num_classes=0,
            features_only=True,
            out_indices=cfg.out_indices,
        )
        feature_dims = self.backbone.feature_info.channels()
        print(f"feature dims: {feature_dims}")
        self.global_pools = torch.nn.ModuleList(
            [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
        )
        # Convert to a native int so the layer constructors below do not
        # receive a NumPy scalar as their size argument.
        self.mid_features = int(np.sum(feature_dims))
        if cfg.normalization == "batchnorm":
            self.neck = torch.nn.BatchNorm1d(self.mid_features)
        elif cfg.normalization == "layernorm":
            self.neck = torch.nn.LayerNorm(self.mid_features)
        else:
            # Fail at construction time instead of with an AttributeError on
            # the first access to ``self.neck``.
            raise ValueError(
                f"Unsupported normalization {cfg.normalization!r}; expected 'batchnorm' or 'layernorm'"
            )
        self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
        self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)
        if id_class_nums is not None and species_class_nums is not None:
            # Per-class margin: nums ** power * coef + cons
            margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
            margins_species = (
                np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species
                + cfg.margin_cons_species
            )
            print("margins_id", margins_id)
            print("margins_species", margins_species)
            self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
            self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
            self.loss_fn_id = torch.nn.CrossEntropyLoss()
            self.loss_fn_species = torch.nn.CrossEntropyLoss()

    def get_feat(self, x: torch.Tensor) -> torch.Tensor:
        """Return the neck-normalised embedding for a batch of images.

        Each backbone feature map is pooled by its matching module in
        ``self.global_pools``; the results are concatenated along dim 1.
        """
        ms = self.backbone(x)
        h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
        return self.neck(h)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(id_logits, species_logits)`` computed from one shared embedding."""
        feat = self.get_feat(x)
        return self.head_id(feat), self.head_species(feat)

    def training_step(self, batch, batch_idx):
        """Compute the weighted id/species loss and log epoch-level metrics.

        ``batch`` is indexed with the keys ``"image"``, ``"label"`` and
        ``"label_species"``. The returned loss is
        ``loss_ids * r + loss_species * (1 - r)`` with
        ``r = hparams.loss_id_ratio``.
        """
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        logits_ids, logits_species = self(x)
        margin_logits_ids = self.margin_fn_id(logits_ids, ids)
        loss_ids = self.loss_fn_id(margin_logits_ids, ids)
        loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
        self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
        self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
        # Metrics are computed on the raw (margin-free) logits and need no gradients.
        with torch.no_grad():
            self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
            self.log_dict(
                {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
                on_step=False,
                on_epoch=True,
            )
        return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)

    def validation_step(self, batch, batch_idx):
        """Log validation metrics on logits averaged over the input and its flip.

        The second pass uses ``x.flip(3)`` (the width axis, assuming NCHW
        images - TODO confirm against the data pipeline).
        """
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        out1, out_species1 = self(x)
        out2, out_species2 = self(x.flip(3))
        output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
        self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
        self.log_dict(
            {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
            on_step=False,
            on_epoch=True,
        )

    def configure_optimizers(self):
        """Create the optimizer (two parameter groups) and the LR scheduler.

        Backbone and pooling parameters use ``hparams.lr_backbone``; neck and
        both heads use ``hparams.lr_head``. The scheduler is a ``LambdaLR``
        driven by ``WarmupCosineLambda``, with step counts derived from
        ``hparams.max_epochs`` and ``hparams.warmup_steps_ratio``.

        Raises:
            ValueError: If ``hparams.optimizer`` is not ``"Adam"``,
                ``"AdamW"`` or ``"RAdam"``.
        """
        backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
        head_params = (
            list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
        )
        params = [
            {"params": backbone_params, "lr": self.hparams.lr_backbone},
            {"params": head_params, "lr": self.hparams.lr_head},
        ]
        # Kept as an if/elif chain so an optimizer class is only looked up on
        # ``torch.optim`` when it is actually selected.
        if self.hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(params)
        elif self.hparams.optimizer == "AdamW":
            optimizer = torch.optim.AdamW(params)
        elif self.hparams.optimizer == "RAdam":
            optimizer = torch.optim.RAdam(params)
        else:
            # Previously this fell through and failed below with an
            # UnboundLocalError on ``optimizer``.
            raise ValueError(
                f"Unsupported optimizer {self.hparams.optimizer!r}; expected 'Adam', 'AdamW' or 'RAdam'"
            )

        warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
        cycle_steps = self.hparams.max_epochs - warmup_steps
        lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        """Return predictions and embeddings for one test batch.

        Logits of the input and of ``x.flip(3)`` are averaged. Id logits are
        sorted in descending order and only the first 1000 columns of the
        sorted values and indices are kept. Both embeddings (input and
        flipped input) are returned on the CPU.
        """
        x = batch["image"]
        feat1 = self.get_feat(x)
        out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
        feat2 = self.get_feat(x.flip(3))
        out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
        pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
        return {
            "original_index": batch["original_index"],
            "label": batch["label"],
            "label_species": batch["label_species"],
            "pred_logit": pred_logit[:, :1000],
            "pred_idx": pred_idx[:, :1000],
            "pred_species": ((out_species1 + out_species2) / 2).cpu(),
            "embed_features1": feat1.cpu(),
            "embed_features2": feat2.cpu(),
        }

    def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        """Gather all test-step outputs and save them from global rank 0.

        Every key of the step dictionaries is concatenated over all batches
        (and over all processes when running distributed) and written with
        ``np.savez_compressed`` to ``self.test_results_fp``.
        """
        outputs = self.all_gather(outputs)
        if self.trainer.global_rank == 0:
            # Whether ``all_gather`` prepended a world-size dimension depends
            # on the number of participating processes, not on how many GPUs
            # are installed in the machine.
            gathered_across_ranks = self.trainer.world_size > 1
            epoch_results: Dict[str, np.ndarray] = {}
            for key in outputs[0].keys():
                if gathered_across_ranks:
                    # (world_size, batch, ...) per step: join the batches on
                    # dim 1, then merge the rank and sample dimensions.
                    result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
                else:
                    result = torch.cat([x[key] for x in outputs], dim=0)
                epoch_results[key] = result.detach().cpu().numpy()
            np.savez_compressed(self.test_results_fp, **epoch_results)