import glob
import logging
import os

import torch

from mono.utils.comm import main_process


def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None):
    """
    Load a checkpoint for resuming training or finetuning.
    """
    logger = logging.getLogger()
    if os.path.isfile(load_path):
        if main_process():
            logger.info(f"Loading weight '{load_path}'")
        checkpoint = torch.load(load_path, map_location="cpu")
        ckpt_state_dict = checkpoint['model_state_dict']
        # The model is assumed to be wrapped (e.g. by DistributedDataParallel),
        # so the weights are loaded into model.module.
        model.module.load_state_dict(ckpt_state_dict, strict=strict_match)

        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint['scheduler'])
        if loss_scaler is not None and 'scaler' in checkpoint:
            # Restore the AMP loss scaler state into loss_scaler, not the scheduler.
            loss_scaler.load_state_dict(checkpoint['scaler'])
        del ckpt_state_dict
        del checkpoint
        if main_process():
            logger.info(f"Successfully loaded weight: '{load_path}'")
            if scheduler is not None and optimizer is not None:
                logger.info(f"Resume training from: '{load_path}'")
    else:
        # Raise on every rank: raising only on the main process would leave the
        # other ranks running with randomly initialized weights.
        raise RuntimeError(f"No weight found at '{load_path}'")
    return model, optimizer, scheduler, loss_scaler
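
# Usage sketch (illustrative, not from the original file): resuming training
# with a DistributedDataParallel-wrapped model. The path below is hypothetical;
# real checkpoints are written by save_ckpt() as '<work_dir>/ckpt/stepXXXXXXXX.pth'.
#
#   model, optimizer, scheduler, loss_scaler = load_ckpt(
#       'path/to/work_dir/ckpt/step00010000.pth',
#       model, optimizer=optimizer, scheduler=scheduler,
#       strict_match=True, loss_scaler=loss_scaler)

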
def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None):
    """
    Save the model, optimizer, and lr scheduler.
    """
    logger = logging.getLogger()

    # Derive the training budget from the runner type (iteration- or epoch-based).
    if 'IterBasedRunner' in cfg.runner.type:
        max_iters = cfg.runner.max_iters
    elif 'EpochBasedRunner' in cfg.runner.type:
        max_iters = cfg.runner.max_epochs
    else:
        raise TypeError(f'{cfg.runner.type} is not supported')

    ckpt = dict(
        model_state_dict=model.module.state_dict(),
        optimizer=optimizer.state_dict(),
        max_iter=max_iters,  # reuse the value derived above instead of re-computing it
        scheduler=scheduler.state_dict(),
    )

    if loss_scaler is not None:
        ckpt.update(dict(scaler=loss_scaler.state_dict()))

    ckpt_dir = os.path.join(cfg.work_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)

    save_name = os.path.join(ckpt_dir, 'step%08d.pth' % curr_iter)
    torch.save(ckpt, save_name)

    # Keep only the 20 most recent checkpoints; the zero-padded step numbers
    # make the lexicographic sort chronological.
    saved_ckpts = sorted(glob.glob(os.path.join(ckpt_dir, 'step*.pth')))
    while len(saved_ckpts) > 20:
        os.remove(saved_ckpts.pop(0))

    logger.info(f'Save model: {save_name}')
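
# Usage sketch (illustrative): the runner type names suggest an mmcv-style
# config, so cfg is assumed to expose cfg.runner.type, cfg.runner.max_iters
# (or cfg.runner.max_epochs), and cfg.work_dir; model is assumed DDP-wrapped.
#
#   if curr_iter % save_interval == 0:  # save_interval is hypothetical
#       save_ckpt(cfg, model, optimizer, scheduler,
#                 curr_iter=curr_iter, loss_scaler=loss_scaler)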