File size: 4,914 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from abc import ABC, abstractmethod
from typing import Dict
from easydict import EasyDict
from ditk import logging
import os
import copy
from typing import Any
from ding.utils import REWARD_MODEL_REGISTRY, import_module, save_file


class BaseRewardModel(ABC):
    """
    Overview:
        The base class of reward model.
    Interface:
        ``default_config``, ``estimate``, ``train``, ``clear_data``, ``collect_data``, ``load_expert_data``, \
        ``reward_deepcopy``, ``state_dict``, ``load_state_dict``, ``save``
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Build the default config of this reward model class from its ``config`` class attribute.
            ``BaseRewardModel`` itself defines no ``config``, so subclasses must provide one.
        Returns:
            - cfg (:obj:`EasyDict`): A deep copy of ``cls.config`` with ``cfg_type`` set to ``<ClassName>Dict``.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    @abstractmethod
    def estimate(self, data: list) -> Any:
        """
        Overview:
            Estimate reward.
        Arguments:
            - data (:obj:`List`): the list of data used for estimation
        Returns / Effects:
            - This can be a side effect function which updates the reward value
            - If this function returns, an example returned object can be reward (:obj:`Any`): the estimated reward
        """
        raise NotImplementedError()

    @abstractmethod
    def train(self, data) -> None:
        """
        Overview:
            Training the reward model.
        Arguments:
            - data (:obj:`Any`): Data used for training
        Effects:
            - This is mostly a side effect function which updates the reward model
        """
        raise NotImplementedError()

    @abstractmethod
    def collect_data(self, data) -> None:
        """
        Overview:
            Collecting training data in designated format or with designated transition.
        Arguments:
            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc)
        Returns / Effects:
            - This can be a side effect function which updates the data attribute in ``self``
        """
        raise NotImplementedError()

    @abstractmethod
    def clear_data(self) -> None:
        """
        Overview:
            Clearing training data. \
            This can be a side effect function which clears the data attribute in ``self``
        """
        raise NotImplementedError()

    def load_expert_data(self, data) -> None:
        """
        Overview:
            Getting the expert data, usually used in inverse RL reward model.
            The base implementation does nothing.
        Arguments:
            - data (:obj:`Any`): Expert data
        Effects:
            This is mostly a side effect function which updates the expert data attribute (e.g. ``self.expert_data``)
        """
        pass

    def reward_deepcopy(self, train_data) -> Any:
        """
        Overview:
            Deep-copy the ``reward`` entry of every sample in ``train_data`` while shallow-copying the other
            entries. This prevents the reward stored in the replay buffer from being modified in place
            when the returned samples' rewards are rewritten.
        Arguments:
            - train_data (:obj:`List`): the list of train data (dict samples) whose reward part is deep-copied.
        Returns:
            - train_data_reward_deepcopy (:obj:`List`): new list of new dicts; only ``reward`` values are copies.
        """
        train_data_reward_deepcopy = [
            {k: copy.deepcopy(v) if k == 'reward' else v
             for k, v in sample.items()} for sample in train_data
        ]
        return train_data_reward_deepcopy

    def state_dict(self) -> Dict:
        """
        Overview:
            Return the state to be checkpointed by ``save``. The base implementation returns an empty dict;
            subclasses should override it.
        """
        return {}

    def load_state_dict(self, _state_dict) -> None:
        """
        Overview:
            Restore state produced by ``state_dict``. The base implementation does nothing;
            subclasses should override it.
        """
        pass

    def save(self, path: str = None, name: str = 'best') -> None:
        """
        Overview:
            Save ``self.state_dict()`` to ``<path>/reward_model/ckpt/ckpt_<name>.pth.tar``,
            creating the directory if needed.
        Arguments:
            - path (:obj:`str`): Root directory. When ``None``, ``self.cfg.exp_name`` is used.
            - name (:obj:`str`): Checkpoint name suffix, ``'best'`` by default.
        """
        if path is None:
            path = self.cfg.exp_name
        ckpt_dir = os.path.join(path, 'reward_model', 'ckpt')
        # exist_ok=True replaces the exists()/try/except FileExistsError sequence and also tolerates the
        # directory being created concurrently between the check and the creation.
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = os.path.join(ckpt_dir, 'ckpt_{}.pth.tar'.format(name))
        state_dict = self.state_dict()
        save_file(ckpt_path, state_dict)
        logging.info('Saved reward model ckpt in {}'.format(ckpt_path))


def create_reward_model(cfg: dict, device: str, tb_logger: 'SummaryWriter') -> BaseRewardModel:  # noqa
    """
    Overview:
        Build a reward model instance from ``REWARD_MODEL_REGISTRY`` according to ``cfg``.
    Arguments:
        - cfg (:obj:`Dict`): Reward model config. If it contains a ``reward_model`` sub-config, the registry \
            type is read (and removed) from ``cfg['reward_model']['type']``; otherwise from ``cfg['type']``. \
            An optional ``import_names`` entry lists modules to import before the registry lookup.
        - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
        - tb_logger (:obj:`SummaryWriter`): Logger passed through to the reward model constructor
    Returns:
        - reward (:obj:`BaseRewardModel`): The reward model
    """
    # Work on a copy so that popping ``import_names`` / ``type`` does not mutate the caller's config.
    cfg = copy.deepcopy(cfg)
    if 'import_names' in cfg:
        import_module(cfg.pop('import_names'))
    # Use key membership (as for ``import_names`` above) rather than ``hasattr`` so plain ``dict`` configs are
    # handled too. For ``EasyDict``, which mirrors keys as attributes, both checks agree.
    if 'reward_model' in cfg:
        reward_model_type = cfg['reward_model'].pop('type')
    else:
        reward_model_type = cfg.pop('type')
    return REWARD_MODEL_REGISTRY.build(reward_model_type, cfg, device=device, tb_logger=tb_logger)


def get_reward_model_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Look up the reward model class registered under ``cfg.type``. Modules listed in
        ``cfg.import_names`` (if any) are imported before the lookup, presumably so that their
        registrations are in place.
    Arguments:
        - cfg (:obj:`EasyDict`): Reward model config; must provide ``type`` and may provide ``import_names``
    Returns:
        - cls (:obj:`type`): The class returned by ``REWARD_MODEL_REGISTRY.get(cfg.type)``
    """
    modules_to_import = cfg.get('import_names', [])
    import_module(modules_to_import)
    registered_type = cfg.type
    return REWARD_MODEL_REGISTRY.get(registered_type)