|
import os
from typing import TYPE_CHECKING, Optional, Union

import numpy as np

from ding.framework import task
from ding.policy import Policy
from ding.utils import save_file

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext

|
class CkptSaver:
    """
    Overview:
        The class used to save checkpoint data of the given policy.
    """

    def __new__(cls, *args, **kwargs):
        # In distributed mode, only learner and evaluator processes need to save
        # checkpoints; every other process receives a no-op placeholder instead.
        if task.router.is_active and not (task.has_role(task.role.LEARNER) or task.has_role(task.role.EVALUATOR)):
            return task.void()
        return super(CkptSaver, cls).__new__(cls)
|
    def __init__(self, policy: Policy, save_dir: str, train_freq: Optional[int] = None, save_finish: bool = True):
        """
        Overview:
            Initialize the `CkptSaver`.
        Arguments:
            - policy (:obj:`Policy`): The policy whose ``learn_mode`` state dict is saved in each checkpoint.
            - save_dir (:obj:`str`): The directory path to save checkpoints in.
            - train_freq (:obj:`int`): Number of training iterations between two checkpoint saves. \
                If ``None``, iteration-based saving is disabled.
            - save_finish (:obj:`bool`): Whether to save a final checkpoint when ``task.finish`` is ``True``.
        """
        self.policy = policy
        self.train_freq = train_freq
        # Save into a "ckpt" subdirectory unless `save_dir` already points at one,
        # e.g. "exp" -> "exp/ckpt", "exp/ckpt" -> "exp/ckpt".
        if os.path.basename(os.path.normpath(save_dir)) != "ckpt":
            self.prefix = '{}/ckpt'.format(os.path.normpath(save_dir))
        else:
            self.prefix = os.path.normpath(save_dir)
        os.makedirs(self.prefix, exist_ok=True)
        self.last_save_iter = 0
        self.max_eval_value = -np.inf
        self.save_finish = save_finish
|
    def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None:
        """
        Overview:
            The method used to save checkpoint data. \
            The checkpoint data will be saved in a file in the following 3 cases: \
                - When at least ``self.train_freq`` training iterations have elapsed since the last save \
                  (the very first iteration also triggers a save); \
                - When the evaluation episode return is the best so far; \
                - When ``task.finish`` is ``True``.
        Input of ctx:
            - train_iter (:obj:`int`): Number of training iterations, i.e. the number of updates of the \
                policy network.
            - eval_value (:obj:`float`): The evaluation episode return of the current iteration.
        """
        # Case 1: periodic save, every `train_freq` training iterations (and at iteration 0).
        if self.train_freq:
            if ctx.train_iter == 0 or ctx.train_iter - self.last_save_iter >= self.train_freq:
                save_file(
                    "{}/iteration_{}.pth.tar".format(self.prefix, ctx.train_iter), self.policy.learn_mode.state_dict()
                )
                self.last_save_iter = ctx.train_iter

        # Case 2: best evaluation episode return so far.
        if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value:
            save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())
            self.max_eval_value = ctx.eval_value

        # Case 3: final checkpoint when the whole task finishes.
        if task.finish and self.save_finish:
            save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())
|
|