from typing import Dict, List, Optional, Tuple
import numpy as np
import timm
import torch
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from .config import Config, load_config
# from .dataset import WhaleDataset, load_df
from .metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM
from .utils import WarmupCosineLambda, map_dict, topk_average_precision
class SphereClassifier(LightningModule):
    """Metric-learning classifier with a shared backbone and two ArcFace heads.

    A timm feature extractor produces multi-scale feature maps; each selected
    stage is GeM-pooled, the pooled vectors are concatenated, normalized by a
    "neck" (BatchNorm1d or LayerNorm), and fed to two sub-center ArcMargin
    heads: one over individual IDs and one over species.
    """

    def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None):
        """Build the network and (optionally) the training losses.

        Args:
            cfg: configuration mapping (coerced to ``Config`` if needed) holding
                model, margin, and optimizer settings.
            id_class_nums: per-class sample counts for the ID head; used to
                derive adaptive ArcFace margins. May be ``None`` for
                inference-only use.
            species_class_nums: per-class sample counts for the species head.
                Margin/loss modules are only created when BOTH count arrays are
                provided.

        Raises:
            ValueError: if ``cfg.normalization`` is not "batchnorm"/"layernorm".
        """
        super().__init__()
        if not isinstance(cfg, Config):
            cfg = Config(cfg)
        # NOTE(review): `ignore` has no effect when cfg is passed positionally;
        # kept as-is for checkpoint/hparams backward compatibility.
        self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"])
        # Path for test_epoch_end's npz dump; assigned externally before testing.
        self.test_results_fp = None

        # --- NN architecture ---
        self.backbone = timm.create_model(
            cfg.model_name,
            in_chans=3,
            pretrained=cfg.pretrained,
            num_classes=0,
            features_only=True,
            out_indices=cfg.out_indices,
        )
        feature_dims = self.backbone.feature_info.channels()
        print(f"feature dims: {feature_dims}")
        # One GeM pool per selected backbone stage.
        self.global_pools = torch.nn.ModuleList(
            [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices]
        )
        self.mid_features = np.sum(feature_dims)
        if cfg.normalization == "batchnorm":
            self.neck = torch.nn.BatchNorm1d(self.mid_features)
        elif cfg.normalization == "layernorm":
            self.neck = torch.nn.LayerNorm(self.mid_features)
        else:
            # Fail fast: previously an unknown value silently left self.neck
            # unset, surfacing later as a confusing AttributeError in get_feat.
            raise ValueError(f"Unsupported normalization: {cfg.normalization!r}")
        self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id)
        self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species)

        # Training-only pieces: adaptive margins scale with class frequency
        # (count ** power * coef + const).
        if id_class_nums is not None and species_class_nums is not None:
            margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id
            margins_species = (
                np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species
                + cfg.margin_cons_species
            )
            print("margins_id", margins_id)
            print("margins_species", margins_species)
            self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id)
            self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species)
            self.loss_fn_id = torch.nn.CrossEntropyLoss()
            self.loss_fn_species = torch.nn.CrossEntropyLoss()

    def get_feat(self, x: torch.Tensor) -> torch.Tensor:
        """Return the pooled, neck-normalized embedding for a batch of images."""
        ms = self.backbone(x)
        # Pool each stage's feature map and concatenate along channels.
        h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
        return self.neck(h)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (id_logits, species_logits) for a batch of images."""
        feat = self.get_feat(x)
        return self.head_id(feat), self.head_species(feat)

    def training_step(self, batch, batch_idx):
        """Compute the margin-adjusted ID/species losses and log metrics.

        Returns the convex combination of the two losses weighted by
        ``loss_id_ratio``. Requires the margin/loss modules, i.e. the model must
        have been built with both class-count arrays.
        """
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        logits_ids, logits_species = self(x)
        margin_logits_ids = self.margin_fn_id(logits_ids, ids)
        loss_ids = self.loss_fn_id(margin_logits_ids, ids)
        loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species)
        self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True)
        self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True)
        # Metrics only — keep them out of the autograd graph.
        with torch.no_grad():
            self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True)
            self.log_dict(
                {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()},
                on_step=False,
                on_epoch=True,
            )
        return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio)

    def validation_step(self, batch, batch_idx):
        """Log validation metrics using horizontal-flip test-time augmentation."""
        x, ids, species = batch["image"], batch["label"], batch["label_species"]
        out1, out_species1 = self(x)
        out2, out_species2 = self(x.flip(3))  # flip width axis (NCHW)
        output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2
        self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True)
        self.log_dict(
            {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()},
            on_step=False,
            on_epoch=True,
        )

    def configure_optimizers(self):
        """Create the optimizer (separate backbone/head LRs) and warmup-cosine scheduler.

        Raises:
            ValueError: if ``hparams.optimizer`` is not Adam/AdamW/RAdam.
        """
        backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters())
        head_params = (
            list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters())
        )
        params = [
            {"params": backbone_params, "lr": self.hparams.lr_backbone},
            {"params": head_params, "lr": self.hparams.lr_head},
        ]
        if self.hparams.optimizer == "Adam":
            optimizer = torch.optim.Adam(params)
        elif self.hparams.optimizer == "AdamW":
            optimizer = torch.optim.AdamW(params)
        elif self.hparams.optimizer == "RAdam":
            optimizer = torch.optim.RAdam(params)
        else:
            # Fail fast: previously an unknown value fell through to a NameError.
            raise ValueError(f"Unsupported optimizer: {self.hparams.optimizer!r}")
        # Schedule is expressed in epochs (warmup ratio of max_epochs, then cosine).
        warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio
        cycle_steps = self.hparams.max_epochs - warmup_steps
        lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return [optimizer], [scheduler]

    def test_step(self, batch, batch_idx):
        """Return flip-TTA predictions, top-1000 ID ranking, and both embeddings."""
        x = batch["image"]
        feat1 = self.get_feat(x)
        out1, out_species1 = self.head_id(feat1), self.head_species(feat1)
        feat2 = self.get_feat(x.flip(3))
        out2, out_species2 = self.head_id(feat2), self.head_species(feat2)
        # Sort averaged ID logits descending; keep only the top 1000 per sample.
        pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True)
        return {
            "original_index": batch["original_index"],
            "label": batch["label"],
            "label_species": batch["label_species"],
            "pred_logit": pred_logit[:, :1000],
            "pred_idx": pred_idx[:, :1000],
            "pred_species": ((out_species1 + out_species2) / 2).cpu(),
            "embed_features1": feat1.cpu(),
            "embed_features2": feat2.cpu(),
        }

    def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]):
        """Gather per-step test outputs across ranks and save them as one npz file.

        Only rank 0 writes, to ``self.test_results_fp``.
        """
        outputs = self.all_gather(outputs)
        if self.trainer.global_rank == 0:
            epoch_results: Dict[str, np.ndarray] = {}
            for key in outputs[0].keys():
                if torch.cuda.device_count() > 1:
                    # all_gather prepends a device dim; flatten (device, batch) -> samples.
                    # NOTE(review): device_count is used as a proxy for world size —
                    # verify against trainer.world_size in multi-node setups.
                    result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1)
                else:
                    result = torch.cat([x[key] for x in outputs], dim=0)
                epoch_results[key] = result.detach().cpu().numpy()
            np.savez_compressed(self.test_results_fp, **epoch_results)