# Copyright (c) Meta Platforms, Inc. and affiliates.

from pathlib import Path

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from torchmetrics import MeanMetric, MetricCollection

from . import logger
from .models import get_model


class AverageKeyMeter(MeanMetric):
    """Running mean of a single key of a dict of tensors, ignoring non-finite values."""

    def __init__(self, key, *args, **kwargs):
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, dict):
        # Keep only finite entries so NaN/Inf values do not corrupt the average.
        value = dict[self.key]
        value = value[torch.isfinite(value)]
        return super().update(value)
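
# Hypothetical usage sketch (not part of the training pipeline): an
# AverageKeyMeter tracks the running mean of one entry of a loss dict, e.g.
#   meter = AverageKeyMeter("total")
#   meter.update({"total": torch.tensor([0.5, float("nan"), 1.5])})
#   meter.compute()  # -> tensor(1.) since the NaN entry is filtered out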


class GenericModule(pl.LightningModule):
    """Generic Lightning wrapper: builds the model from the config and handles
    training, validation, and checkpoint loading."""

    def __init__(self, cfg):
        super().__init__()
        name = cfg.model.get("name")
        name = "orienternet" if name in ("localizer_bev_depth", None) else name
        self.model = get_model(name)(cfg.model)
        self.cfg = cfg
        self.save_hyperparameters(cfg)
        self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
        self.losses_val = None  # we do not know the loss keys in advance
        # Disabled per-city validation metrics, kept for reference:
        # self.citys = self.cfg.data.val_citys
        # for i in range(len(self.citys)):
        #     city = self.citys[i]
        #     setattr(
        #         self,
        #         "metric_vals_{}".format(i),
        #         MetricCollection(self.model.metrics(), prefix="val_{}/".format(city)),
        #     )
        # self.losse_vals = [None for city in self.cfg.data.val_citys]
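
    # A minimal configuration sketch, assuming OmegaConf and illustrative keys
    # (the real defaults live in this repository's config files):
    #   cfg = OmegaConf.create(
    #       {"model": {"name": "orienternet"}, "training": {"lr": 1.0e-4}}
    #   )
    #   module = GenericModule(cfg)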

    def forward(self, batch):
        return self.model(batch)

    def training_step(self, batch):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        self.log_dict(
            {f"loss/{k}/train": v.mean() for k, v in losses.items()},
            prog_bar=True,
            rank_zero_only=True,
        )
        return losses["total"].mean()
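
    # self.model.loss is expected to return a dict of loss tensors that includes
    # a "total" entry; training minimizes the batch mean of "total" while every
    # entry is also logged as "loss/<key>/train".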

    # def validation_step(self, batch, batch_idx, dataloader_idx):
    #     city = self.citys[dataloader_idx]
    #
    #     pred = self(batch)
    #     losses = self.model.loss(pred, batch)
    #
    #     if not hasattr(self, "losse_val_{}".format(dataloader_idx)):
    #         setattr(
    #             self,
    #             "losse_val_{}".format(dataloader_idx),
    #             MetricCollection(
    #                 {k: AverageKeyMeter(k).to(self.device) for k in losses},
    #                 prefix="loss_{}/".format(city),
    #                 postfix="/val_{}".format(city),
    #             ),
    #         )
    #
    #     metrics = getattr(self, "metric_vals_{}".format(dataloader_idx))
    #     self.log_dict(metrics(pred, batch), sync_dist=True)
    #
    #     losses_val = getattr(self, "losse_val_{}".format(dataloader_idx))
    #     losses_val.update(losses)
    #     self.log_dict(losses_val.compute(), sync_dist=True)
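
    # The block above is a disabled variant that evaluates one validation
    # dataloader per city (cfg.data.val_citys) with separate metric and loss
    # collections; the active validation_step below aggregates everything into
    # a single collection instead.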

    def validation_step(self, batch, batch_idx):
        pred = self(batch)
        losses = self.model.loss(pred, batch)
        if self.losses_val is None:
            # The loss keys are only known after the first forward pass,
            # so the loss meters are created lazily here.
            self.losses_val = MetricCollection(
                {k: AverageKeyMeter(k).to(self.device) for k in losses},
                prefix="loss/",
                postfix="/val",
            )
        self.metrics_val(pred, batch)
        self.log_dict(self.metrics_val, sync_dist=True)
        self.losses_val.update(losses)
        self.log_dict(self.losses_val, sync_dist=True)
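
    # Passing the metric objects themselves to log_dict lets Lightning compute
    # and reset them at the end of each validation epoch rather than logging
    # raw per-step values.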

    def on_validation_epoch_start(self):
        # Lightning hook: reset the lazily-created loss meters for a new epoch.
        self.losses_val = None
        # self.losse_val = [None for city in self.cfg.data.val_citys]

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
        ret = {"optimizer": optimizer}
        cfg_scheduler = self.cfg.training.get("lr_scheduler")
        if cfg_scheduler is not None:
            scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
                optimizer=optimizer, **cfg_scheduler.get("args", {})
            )
            ret["lr_scheduler"] = {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
                "monitor": "loss/total/val",
                "strict": True,
                "name": "learning_rate",
            }
        return ret
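
    # Illustrative scheduler entry for the training config (keys assumed; any
    # class from torch.optim.lr_scheduler can be named):
    #   training:
    #     lr: 1.0e-4
    #     lr_scheduler:
    #       name: ReduceLROnPlateau
    #       args: {factor: 0.5, patience: 3}
    # The "monitor"/"strict" entries above only take effect for schedulers such
    # as ReduceLROnPlateau that step on the monitored validation loss.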

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path,
        map_location=None,
        hparams_file=None,
        strict=True,
        cfg=None,
        find_best=False,
    ):
        assert hparams_file is None, "hparams are not supported."
        checkpoint = torch.load(
            checkpoint_path, map_location=map_location or (lambda storage, loc: storage)
        )
        if find_best:
            # Pick the best checkpoint tracked by any ModelCheckpoint callback
            # stored in this checkpoint, then reload it.
            best_score, best_name = None, None
            modes = {"min": torch.lt, "max": torch.gt}
            for key, state in checkpoint["callbacks"].items():
                if not key.startswith("ModelCheckpoint"):
                    continue
                # The callback key embeds the ModelCheckpoint kwargs as a dict repr.
                mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
                if best_score is None or modes[mode](
                    state["best_model_score"], best_score
                ):
                    best_score = state["best_model_score"]
                    best_name = Path(state["best_model_path"]).name
            logger.info("Loading best checkpoint %s", best_name)
            if best_name != Path(checkpoint_path).name:
                return cls.load_from_checkpoint(
                    Path(checkpoint_path).parent / best_name,
                    map_location,
                    hparams_file,
                    strict,
                    cfg,
                    find_best=False,
                )
        logger.info(
            "Using checkpoint %s from epoch %d and step %d.",
            Path(checkpoint_path).name,
            checkpoint["epoch"],
            checkpoint["global_step"],
        )
        cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        if list(cfg_ckpt.keys()) == ["cfg"]:  # backward compatibility
            cfg_ckpt = cfg_ckpt["cfg"]
        cfg_ckpt = OmegaConf.create(cfg_ckpt)
        if cfg is None:
            cfg = {}
        if not isinstance(cfg, DictConfig):
            cfg = OmegaConf.create(cfg)
        # Overrides passed at load time take precedence over the stored config.
        with open_dict(cfg_ckpt):
            cfg = OmegaConf.merge(cfg_ckpt, cfg)
        return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
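

# Hypothetical usage sketch (paths are made up): reload a trained module and
# optionally override parts of its stored configuration.
#   model = GenericModule.load_from_checkpoint(
#       "experiments/my_run/last.ckpt",
#       find_best=True,  # follow the stored ModelCheckpoint state to the best epoch
#       cfg={"training": {"lr": 1.0e-4}},
#   )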