File size: 2,001 Bytes
62c7319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import os
import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from loguru import logger
import gc
import roma
class CheckPoint:
    """Save and restore the most recent training state.

    A single file, ``<dir>/<name>_latest.pth``, holds the model, optimizer and
    LR-scheduler state dicts plus the step counter ``n``; every ``save`` call
    overwrites it. Both ``save`` and ``load`` only act when ``roma.RANK == 0``.
    """

    def __init__(self, dir=None, name="tmp"):
        """Remember where checkpoints live and create that directory.

        Args:
            dir: Directory that holds the checkpoint file. ``None`` means the
                current working directory.
            name: Prefix of the checkpoint file name.
        """
        self.name = name
        self.dir = dir if dir is not None else "."
        os.makedirs(self.dir, exist_ok=True)

    def _checkpoint_path(self):
        """Return the path of the single "latest" checkpoint file."""
        # os.path.join (rather than string concatenation) keeps the file inside
        # ``dir`` even when ``dir`` has no trailing separator.
        return os.path.join(self.dir, f"{self.name}_latest.pth")

    @staticmethod
    def _unwrap(model):
        """Return the wrapped module of a (Distributed)DataParallel, else ``model``."""
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            return model.module
        return model

    def save(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Write model/optimizer/scheduler state and step ``n`` (rank 0 only).

        Args:
            model: Module to save. (Distributed)DataParallel wrappers are
                unwrapped so the stored keys carry no ``module.`` prefix.
            optimizer: Object whose ``state_dict()`` is stored.
            lr_scheduler: Object whose ``state_dict()`` is stored.
            n: Step counter stored alongside the states.

        Raises:
            ValueError: On rank 0, if ``model`` is None.
        """
        if roma.RANK != 0:
            return
        if model is None:
            raise ValueError("model must not be None")
        model = self._unwrap(model)
        states = {
            "model": model.state_dict(),
            "n": n,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        }
        path = self._checkpoint_path()
        tmp_path = path + ".tmp"
        # Write to a side file, then atomically move it into place, so an
        # interruption mid-write cannot destroy the previous good checkpoint.
        torch.save(states, tmp_path)
        os.replace(tmp_path, path)
        logger.info(f"Saved states {list(states.keys())}, at step {n}")

    def load(
        self,
        model,
        optimizer,
        lr_scheduler,
        n,
    ):
        """Restore state saved by ``save`` into the given objects (rank 0 only).

        Nothing happens when the checkpoint file does not exist or the rank is
        not 0. Each entry is restored only if present in the file.

        Args:
            model: Module to load into (may be a (Distributed)DataParallel).
            optimizer: Optimizer to load into. A failure here is logged and
                ignored (best effort).
            lr_scheduler: Scheduler to load into.
            n: Fallback step counter, kept when the stored step is falsy.

        Returns:
            Tuple ``(model, optimizer, lr_scheduler, n)``. The same objects are
            passed in (updated in place); ``n`` is the restored step.
        """
        path = self._checkpoint_path()
        if os.path.exists(path) and roma.RANK == 0:
            # Deserialize onto CPU: load_state_dict copies tensors onto the
            # target parameters' devices anyway, so this avoids a transient GPU
            # allocation and device mismatches when saved from another GPU.
            states = torch.load(path, map_location="cpu")
            if "model" in states:
                # save() stores the unwrapped module's keys, so load into the
                # unwrapped module; the caller's wrapper is what gets returned.
                self._unwrap(model).load_state_dict(states["model"])
            if "n" in states:
                # A falsy stored step (None or 0) keeps the caller-provided n.
                n = states["n"] if states["n"] else n
            if "optimizer" in states:
                try:
                    optimizer.load_state_dict(states["optimizer"])
                except Exception as e:
                    # Deliberate best effort, e.g. parameter groups that no
                    # longer match the saved ones.
                    logger.warning(f"Failed to load states for optimizer, with error {e}")
            if "lr_scheduler" in states:
                lr_scheduler.load_state_dict(states["lr_scheduler"])
            logger.info(f"Loaded states {list(states.keys())}, at step {n}")
            # Drop the loaded copy promptly to release its memory.
            del states
            gc.collect()
            torch.cuda.empty_cache()
        return model, optimizer, lr_scheduler, n