from typing import List, Dict, Any
from easydict import EasyDict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Independent, Normal

from ding.utils import REWARD_MODEL_REGISTRY
from ding.utils.data import default_collate
from .base_reward_model import BaseRewardModel


class GuidedCostNN(nn.Module):
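    """
    Overview:
        Simple MLP cost network: two hidden ReLU layers mapping a concatenated (obs, action) \
            vector to a single scalar cost.
    """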

    def __init__(
        self,
        input_size,
        hidden_size=128,
        output_size=1,
    ):
        super(GuidedCostNN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        return self.net(x)


@REWARD_MODEL_REGISTRY.register('guided_cost')
class GuidedCostRewardModel(BaseRewardModel):
    """
    Overview:
        Reward model of the guided cost learning (GCL) algorithm. (https://arxiv.org/pdf/1603.00448.pdf)
    Interface:
        ``__init__``, ``estimate``, ``train``, ``collect_data``, ``clear_data``, \
        ``state_dict_reward_model``, ``load_state_dict_reward_model``
    Config:
        == ====================  ========   =============  ========================================  ================
        ID Symbol                Type       Default Value  Description                               Other(Shape)
        == ====================  ========   =============  ========================================  ================
        1  ``type``              str         guided_cost   | Reward model register name, refer        |
                                                           | to registry ``REWARD_MODEL_REGISTRY``    |
        2  | ``continuous``      bool        True          | Whether action is continuous             |
        3  | ``learning_rate``   float       0.001         | learning rate for optimizer              |
        4  | ``update_per_``     int         100           | Number of updates per collect            |
           | ``collect``                                   |                                          |
        5  | ``batch_size``      int         64            | Training batch size                      |
        6  | ``hidden_size``     int         128           | Linear model hidden size                 |
        7  | ``action_shape``    int         1             | Action space shape                       |
        8  | ``log_every_n``     int         50            | add loss to log every n iteration        |
           | ``_train``                                    |                                          |
        9  | ``store_model_``    int         100           | save model every n iteration             |
           | ``every_n_train``                             |                                          |
        == ====================  ========   =============  ========================================  ================
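
    Examples:
        >>> # A minimal construction sketch (illustrative, not from the original file); ``input_size``
        >>> # is assumed to be obs_dim + action_dim, and ``tb_logger`` an existing SummaryWriter.
        >>> cfg = EasyDict(GuidedCostRewardModel.config)
        >>> cfg.input_size = 4 + 1  # hypothetical obs_dim=4, action_dim=1
        >>> reward_model = GuidedCostRewardModel(cfg, device='cpu', tb_logger=tb_logger)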

    """

    config = dict(
        # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
        type='guided_cost',
        # (float) The step size of gradient descent.
        learning_rate=1e-3,
        # (int) Action space shape, such as 1.
        action_shape=1,
        # (bool) Whether action is continuous.
        continuous=True,
        # (int) How many samples in a training batch.
        batch_size=64,
        # (int) Linear model hidden size.
        hidden_size=128,
        # (int) How many updates(iterations) to train after collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy-> collect data -> ...
        update_per_collect=100,
        # (int) Add loss to log every n iteration.
        log_every_n_train=50,
        # (int) Save model every n iteration.
        store_model_every_n_train=100,
    )

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        super(GuidedCostRewardModel, self).__init__()
        self.cfg = config
        self.action_shape = self.cfg.action_shape
        assert device == "cpu" or device.startswith("cuda")
        self.device = device
        self.tb_logger = tb_logger
        self.reward_model = GuidedCostNN(config.input_size, config.hidden_size)
        self.reward_model.to(self.device)
        self.opt = optim.Adam(self.reward_model.parameters(), lr=config.learning_rate)

    def train(self, expert_demo: List[Dict[str, Any]], samp: List[Dict[str, Any]], iter: int, step: int) -> None:
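        """
        Overview:
            Perform one gradient update of the cost network using a list of expert transitions \
                (``expert_demo``) and a list of policy-sampled transitions (``samp``); ``iter`` and \
                ``step`` are the current training iteration and environment step, used for logging.
        """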
        device_0 = expert_demo[0]['obs'].device
        device_1 = samp[0]['obs'].device
        for i in range(len(expert_demo)):
            expert_demo[i]['prob'] = torch.FloatTensor([1]).to(device_0)
        if self.cfg.continuous:
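            # For continuous action spaces, ``logit`` stores a (mu, sigma) pair; the sample's
            # probability is the density of a diagonal Gaussian (Independent Normal) at the taken action.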
            for i in range(len(samp)):
                (mu, sigma) = samp[i]['logit']
                dist = Independent(Normal(mu, sigma), 1)
                next_action = samp[i]['action']
                log_prob = dist.log_prob(next_action)
                samp[i]['prob'] = torch.exp(log_prob).unsqueeze(0).to(device_1)
        else:
            for i in range(len(samp)):
                probs = F.softmax(samp[i]['logit'], dim=-1)
                prob = probs[samp[i]['action']]
                samp[i]['prob'] = prob.to(device_1)
        # Mix the expert data and sample data to train the reward model.
        samp.extend(expert_demo)
        expert_demo = default_collate(expert_demo)
        samp = default_collate(samp)
        cost_demo = self.reward_model(
            torch.cat([expert_demo['obs'], expert_demo['action'].float().reshape(-1, self.action_shape)], dim=-1)
        )
        cost_samp = self.reward_model(
            torch.cat([samp['obs'], samp['action'].float().reshape(-1, self.action_shape)], dim=-1)
        )

        prob = samp['prob'].unsqueeze(-1)
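        # Importance-sampled MaxEnt IOC objective from guided cost learning:
        #   L_IOC = E_demo[c(x)] + log E_mix[exp(-c(x)) / q(x)],
        # where c is the learned cost, the second expectation runs over the mixed policy + expert
        # samples built above, and q is each sample's stored ``prob``; 1e-7 guards the division.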
        loss_IOC = torch.mean(cost_demo) + \
            torch.log(torch.mean(torch.exp(-cost_samp)/(prob+1e-7)))
        # UPDATING THE COST FUNCTION
        self.opt.zero_grad()
        loss_IOC.backward()
        self.opt.step()
        if iter % self.cfg.log_every_n_train == 0:
            self.tb_logger.add_scalar('reward_model/loss_iter', loss_IOC, iter)
            self.tb_logger.add_scalar('reward_model/loss_step', loss_IOC, step)

    def estimate(self, data: list) -> List[Dict]:
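        """
        Overview:
            Estimate the reward of each transition in ``data`` with the learned cost network, writing the \
                negative cost into the ``reward`` field in place, and return the modified list.
        """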
        # NOTE: this estimate method of gcl alg. is a little different from the one in other irl alg.,
        # because its deepcopy is operated before learner train loop.
        train_data_augmented = data
        for i in range(len(train_data_augmented)):
            with torch.no_grad():
                reward = self.reward_model(
                    torch.cat([train_data_augmented[i]['obs'], train_data_augmented[i]['action'].float()]).unsqueeze(0)
                ).squeeze(0)
                train_data_augmented[i]['reward'] = -reward

        return train_data_augmented

    def collect_data(self, data) -> None:
        """
        Overview:
            Collect training data. Not needed if the reward model (i.e. online_net) is only trained once; \
                if online_net is trained continuously, collect_data should be implemented here.
        """
        # if online_net is trained continuously, there should be some implementations in collect_data method
        pass

    def clear_data(self):
        """
        Overview:
            Clear collected data. Not needed if the reward model (i.e. online_net) is only trained once; \
                if online_net is trained continuously, clear_data should be implemented here.
        """
        # if online_net is trained continuously, there should be some implementations in clear_data method
        pass

    def state_dict_reward_model(self) -> Dict[str, Any]:
        return {
            'model': self.reward_model.state_dict(),
            'optimizer': self.opt.state_dict(),
        }

    def load_state_dict_reward_model(self, state_dict: Dict[str, Any]) -> None:
        self.reward_model.load_state_dict(state_dict['model'])
        self.opt.load_state_dict(state_dict['optimizer'])
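

# Hedged usage sketch (illustrative, not part of the original module): a typical reward-model
# update inside an IRL training loop might look like the following, where ``sample_expert`` and
# ``sample_policy`` are hypothetical helpers returning lists of transition dicts:
#
#   for _ in range(cfg.update_per_collect):
#       expert_batch = sample_expert(cfg.batch_size)
#       policy_batch = sample_policy(cfg.batch_size)
#       reward_model.train(expert_batch, policy_batch, iter=train_iter, step=env_step)
#   train_data = reward_model.estimate(train_data)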