File size: 7,469 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from typing import Dict
import numpy as np
from mlagents.torch_utils import torch

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.demo_loader import demo_to_buffer
from mlagents.trainers.settings import BehavioralCloningSettings, ScheduleType
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.action_log_probs import ActionLogProbs
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.buffer import AgentBuffer


class BCModule:
    """
    Behavioral cloning (BC) trainer that can be run inline with an RL trainer.

    Demonstrations are loaded into an AgentBuffer once, at construction. Each call
    to ``update`` fits the policy's actor to the demonstrated actions with its own
    Adam optimizer, whose learning rate follows a (possibly annealed) schedule.
    """

    def __init__(
        self,
        policy: TorchPolicy,
        settings: BehavioralCloningSettings,
        policy_learning_rate: float,
        default_batch_size: int,
        default_num_epoch: int,
    ):
        """
        A BC trainer that can be used inline with RL.
        :param policy: The policy of the learning model.
        :param settings: The settings for BehavioralCloning, including LR strength,
            batch_size, num_epoch, samples_per_update, LR annealing steps and the
            demonstration path.
        :param policy_learning_rate: The initial learning rate of the policy. It is
            multiplied by ``settings.strength`` to get the pretrainer's learning rate.
        :param default_batch_size: Batch size used when ``settings.batch_size`` is
            unset (falsy).
        :param default_num_epoch: Number of epochs used when ``settings.num_epoch``
            is unset (falsy).
        """
        self.policy = policy
        self._anneal_steps = settings.steps
        self.current_lr = policy_learning_rate * settings.strength

        # Linear annealing when a positive step count is configured, otherwise a
        # constant rate. 1e-10 is the floor value handed to DecayedValue
        # (presumably the schedule's minimum - see the early exit in `update`).
        learning_rate_schedule: ScheduleType = (
            ScheduleType.LINEAR if self._anneal_steps > 0 else ScheduleType.CONSTANT
        )
        self.decay_learning_rate = ModelUtils.DecayedValue(
            learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps
        )
        # Only the actor's parameters are trained by BC.
        self.optimizer = torch.optim.Adam(
            self.policy.actor.parameters(), lr=self.current_lr
        )
        _, self.demonstration_buffer = demo_to_buffer(
            settings.demo_path, policy.sequence_length, policy.behavior_spec
        )
        self.batch_size = settings.batch_size or default_batch_size
        self.num_epoch = settings.num_epoch or default_num_epoch
        # Sequences per mini-batch: the batch size is capped at the number of
        # demonstration experiences, and at least one sequence is always used.
        self.n_sequences = max(
            min(self.batch_size, self.demonstration_buffer.num_experiences)
            // policy.sequence_length,
            1,
        )

        self.has_updated = False
        self.use_recurrent = self.policy.use_recurrent
        self.samples_per_update = settings.samples_per_update

    def update(self) -> Dict[str, float]:
        """
        Run ``num_epoch`` epochs of behavioral cloning over the demonstrations.

        Each epoch reshuffles the demonstration buffer, then trains on consecutive
        mini-batches of ``n_sequences * sequence_length`` experiences. The number
        of mini-batches is capped by ``samples_per_update`` when that is non-zero.
        Afterwards the learning rate is advanced along its schedule.
        :return: Stats dict with the mean pretraining loss over all mini-batches
            (0 when training was skipped or no mini-batch was run).
        """
        # The annealed LR for the current step is only applied after this update,
        # so this update still trains with the previously applied rate.
        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        # Don't continue training if the learning rate has reached 0, to reduce
        # training time. Unlike in TF, this never actually reaches 0.
        if self.current_lr <= 1e-10:
            return {"Losses/Pretraining Loss": 0}

        seq_len = self.policy.sequence_length
        batch_len = self.n_sequences * seq_len
        # Both counts are divided by n_sequences here and by seq_len in the loop
        # below, i.e. they are measured in mini-batches of batch_len experiences.
        possible_batches = self.demonstration_buffer.num_experiences // self.n_sequences
        max_batches = self.samples_per_update // self.n_sequences
        # A zero cap (samples_per_update smaller than n_sequences, e.g. 0) means
        # "use every available mini-batch".
        num_batches = (
            possible_batches if max_batches == 0 else min(possible_batches, max_batches)
        )

        batch_losses = []
        for _ in range(self.num_epoch):
            self.demonstration_buffer.shuffle(sequence_length=seq_len)
            for i in range(num_batches // seq_len):
                mini_batch_demo = self.demonstration_buffer.make_mini_batch(
                    i * batch_len, (i + 1) * batch_len
                )
                run_out = self._update_batch(mini_batch_demo, self.n_sequences)
                batch_losses.append(run_out["loss"])

        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.current_lr = decay_lr

        self.has_updated = True
        # np.mean([]) is NaN (and warns); this happens when num_epoch is 0 or
        # fewer samples than one mini-batch are available, so report 0 instead.
        mean_loss = np.mean(batch_losses) if batch_losses else 0.0
        return {"Losses/Pretraining Loss": mean_loss}

    def _behavioral_cloning_loss(
        self,
        selected_actions: AgentAction,
        log_probs: ActionLogProbs,
        expert_actions: AgentAction,
    ) -> torch.Tensor:
        """
        Compute the BC loss between the actor's output and the expert's actions.

        Continuous part: mean squared error between the selected and the expert
        continuous actions. Discrete part: per branch, the sum over actions of
        -log_softmax(branch log-probs) * one-hot expert action, then averaged over
        branches and entries. Both parts are added when both action types exist.
        :param selected_actions: Actions returned by the actor for the mini-batch.
        :param log_probs: Log-probabilities returned alongside ``selected_actions``.
        :param expert_actions: Actions recorded in the demonstration mini-batch.
        :return: The loss tensor (the plain int 0 if the action spec has neither
            continuous nor discrete actions).
        """
        action_spec = self.policy.behavior_spec.action_spec
        bc_loss = 0
        if action_spec.continuous_size > 0:
            bc_loss += torch.nn.functional.mse_loss(
                selected_actions.continuous_tensor, expert_actions.continuous_tensor
            )
        if action_spec.discrete_size > 0:
            one_hot_expert_actions = ModelUtils.actions_to_onehot(
                expert_actions.discrete_tensor, action_spec.discrete_branches
            )
            log_prob_branches = ModelUtils.break_into_branches(
                log_probs.all_discrete_tensor, action_spec.discrete_branches
            )
            # One cross-entropy term per branch; each is summed over dim=1
            # (presumably the branch's action dimension - TODO confirm).
            per_branch_losses = [
                torch.sum(
                    -torch.nn.functional.log_softmax(log_prob_branch, dim=1)
                    * expert_actions_branch,
                    dim=1,
                )
                for log_prob_branch, expert_actions_branch in zip(
                    log_prob_branches, one_hot_expert_actions
                )
            ]
            bc_loss += torch.mean(torch.stack(per_branch_losses))
        return bc_loss

    def _update_batch(
        self, mini_batch_demo: AgentBuffer, n_sequences: int
    ) -> Dict[str, float]:
        """
        Take a single optimizer step on one demonstration mini-batch.
        :param mini_batch_demo: Mini-batch of ``n_sequences * sequence_length``
            demonstration experiences.
        :param n_sequences: Number of sequences in the mini-batch; sizes the action
            masks and the initial recurrent memories.
        :return: Dict holding the scalar BC loss under ``"loss"``.
        """
        behavior_spec = self.policy.behavior_spec
        np_obs = ObsUtil.from_buffer(
            mini_batch_demo, len(behavior_spec.observation_specs)
        )
        # Convert to tensors
        tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]
        expert_actions = AgentAction.from_buffer(mini_batch_demo)

        act_masks = None
        if behavior_spec.action_spec.discrete_size > 0:
            # An all-ones mask leaves every discrete action available.
            act_masks = ModelUtils.list_to_tensor(
                np.ones(
                    (
                        n_sequences * self.policy.sequence_length,
                        sum(behavior_spec.action_spec.discrete_branches),
                    ),
                    dtype=np.float32,
                )
            )

        memories = []
        if self.policy.use_recurrent:
            # Every sequence starts from zeroed memories.
            memories = torch.zeros(1, n_sequences, self.policy.m_size)

        selected_actions, run_out, _ = self.policy.actor.get_action_and_stats(
            tensor_obs,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )
        bc_loss = self._behavioral_cloning_loss(
            selected_actions, run_out["log_probs"], expert_actions
        )
        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()
        return {"loss": bc_loss.item()}