import os

import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from loguru import logger


class CheckPoint:
    """Saves the latest training state (model, optimizer, LR scheduler) to disk."""

    def __init__(self, dir=None, name="tmp"):
        if dir is None:
            raise ValueError("CheckPoint requires a target directory")
        self.name = name
        self.dir = dir
        os.makedirs(self.dir, exist_ok=True)

    def __call__(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        assert model is not None
        # Unwrap DataParallel / DistributedDataParallel so the saved state dict
        # keys are not prefixed with "module." and can be loaded into a bare model.
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            model = model.module
        states = {
            "model": model.state_dict(),
            "n": n,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }
        # Join with os.path.join rather than string concatenation so the filename
        # is not glued onto the directory name when self.dir lacks a trailing slash.
        torch.save(states, os.path.join(self.dir, f"{self.name}_latest.pth"))
        logger.info(f"Saved states {list(states.keys())} at step {n}")
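

# Minimal usage sketch of CheckPoint inside a training loop. The model, optimizer,
# scheduler, directory, and name below (net, opt, sched, "checkpoints", "demo") are
# illustrative placeholders, not part of the original module.
if __name__ == "__main__":
    net = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=10)
    checkpoint = CheckPoint(dir="checkpoints", name="demo")

    for step in range(3):
        opt.zero_grad()
        loss = net(torch.randn(8, 4)).sum()
        loss.backward()
        opt.step()
        sched.step()
        # Writes <dir>/<name>_latest.pth, overwriting the previous checkpoint.
        checkpoint(net, opt, sched, n=step)

    # Restoring mirrors the keys that __call__ saves.
    states = torch.load(os.path.join("checkpoints", "demo_latest.pth"))
    net.load_state_dict(states["model"])
    opt.load_state_dict(states["optimizer"])
    sched.load_state_dict(states["lr_scheduler"])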