zjowowen's picture
init space
079c32c
raw
history blame
3.19 kB
from typing import TYPE_CHECKING, Optional, Union
from easydict import EasyDict
import os
import numpy as np
from ding.utils import save_file
from ding.policy import Policy
from ding.framework import task
if TYPE_CHECKING:
from ding.framework import OnlineRLContext, OfflineRLContext
class CkptSaver:
"""
Overview:
The class used to save checkpoint data.
"""
def __new__(cls, *args, **kwargs):
if task.router.is_active and not (task.has_role(task.role.LEARNER) or task.has_role(task.role.EVALUATOR)):
return task.void()
return super(CkptSaver, cls).__new__(cls)
def __init__(self, policy: Policy, save_dir: str, train_freq: Optional[int] = None, save_finish: bool = True):
"""
Overview:
Initialize the `CkptSaver`.
Arguments:
- policy (:obj:`Policy`): Policy used to save the checkpoint.
- save_dir (:obj:`str`): The directory path to save ckpt.
- train_freq (:obj:`int`): Number of training iterations between each saving checkpoint data.
- save_finish (:obj:`bool`): Whether save final ckpt when ``task.finish = True``.
"""
self.policy = policy
self.train_freq = train_freq
if str(os.path.basename(os.path.normpath(save_dir))) != "ckpt":
self.prefix = '{}/ckpt'.format(os.path.normpath(save_dir))
else:
self.prefix = '{}/'.format(os.path.normpath(save_dir))
if not os.path.exists(self.prefix):
os.makedirs(self.prefix)
self.last_save_iter = 0
self.max_eval_value = -np.inf
self.save_finish = save_finish
def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None:
"""
Overview:
The method used to save checkpoint data. \
The checkpoint data will be saved in a file in following 3 cases: \
- When a multiple of `self.train_freq` iterations have elapsed since the beginning of training; \
- When the evaluation episode return is the best so far; \
- When `task.finish` is True.
Input of ctx:
- train_iter (:obj:`int`): Number of training iteration, i.e. the number of updating policy related network.
- eval_value (:obj:`float`): The episode return of current iteration.
"""
# train enough iteration
if self.train_freq:
if ctx.train_iter == 0 or ctx.train_iter - self.last_save_iter >= self.train_freq:
save_file(
"{}/iteration_{}.pth.tar".format(self.prefix, ctx.train_iter), self.policy.learn_mode.state_dict()
)
self.last_save_iter = ctx.train_iter
# best episode return so far
if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value:
save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())
self.max_eval_value = ctx.eval_value
# finish
if task.finish and self.save_finish:
save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())