from typing import TYPE_CHECKING, Optional, Union
import os
import numpy as np

from ding.utils import save_file
from ding.policy import Policy
from ding.framework import task

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


class CkptSaver:
    """
    Overview:
        The class used to save checkpoint data.
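    Examples:
        >>> # A usage sketch inside a DI-engine task pipeline (``policy`` and
        >>> # the rest of the pipeline are assumed to be set up by the caller):
        >>> task.use(CkptSaver(policy, save_dir='./exp', train_freq=100))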
    """

    def __new__(cls, *args, **kwargs):
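        # In parallel mode, only processes holding the LEARNER or EVALUATOR
        # role need to write checkpoints; every other process receives a no-op
        # placeholder from ``task.void()`` instead of a real ``CkptSaver``.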
        if task.router.is_active and not (task.has_role(task.role.LEARNER) or task.has_role(task.role.EVALUATOR)):
            return task.void()
        return super(CkptSaver, cls).__new__(cls)

    def __init__(self, policy: Policy, save_dir: str, train_freq: Optional[int] = None, save_finish: bool = True):
        """
        Overview:
            Initialize the `CkptSaver`.
        Arguments:
            - policy (:obj:`Policy`): The policy whose ``learn_mode`` state dict is saved as the checkpoint.
            - save_dir (:obj:`str`): The directory path to save checkpoints in.
            - train_freq (:obj:`int`): Number of training iterations between two checkpoint saves. \
                If None, periodic saving is disabled.
            - save_finish (:obj:`bool`): Whether to save a final checkpoint when ``task.finish`` is True.
        """
        self.policy = policy
        self.train_freq = train_freq
        # Always save into a ``ckpt`` sub-directory, unless ``save_dir`` already points to one.
        save_dir = os.path.normpath(save_dir)
        if os.path.basename(save_dir) != "ckpt":
            self.prefix = os.path.join(save_dir, 'ckpt')
        else:
            self.prefix = save_dir
        os.makedirs(self.prefix, exist_ok=True)
        self.last_save_iter = 0
        self.max_eval_value = -np.inf
        self.save_finish = save_finish

    def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None:
        """
        Overview:
            The method used to save checkpoint data. \
            The checkpoint data will be saved in a file in the following 3 cases: \
                - when at least ``self.train_freq`` iterations have elapsed since the last save \
                  (or at iteration 0), saved as ``iteration_{train_iter}.pth.tar``; \
                - when the evaluation episode return is the best so far, saved as ``eval.pth.tar``; \
                - when ``task.finish`` is True, saved as ``final.pth.tar``.
        Input of ctx:
            - train_iter (:obj:`int`): Number of training iterations, i.e. the number of policy network updates.
            - eval_value (:obj:`float`): The evaluation episode return of the current iteration.
        """
        # Periodic save: enough training iterations since the last save (or iteration 0).
        if self.train_freq:
            if ctx.train_iter == 0 or ctx.train_iter - self.last_save_iter >= self.train_freq:
                save_file(
                    "{}/iteration_{}.pth.tar".format(self.prefix, ctx.train_iter), self.policy.learn_mode.state_dict()
                )
                self.last_save_iter = ctx.train_iter

        # best episode return so far
        if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value:
            save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())
            self.max_eval_value = ctx.eval_value

        # finish
        if task.finish and self.save_finish:
            save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())