from typing import Dict, cast

import attr

from mlagents.torch_utils import torch, default_device

from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil
from mlagents_envs.timers import timed
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import (
    TrainerSettings,
    OnPolicyHyperparamSettings,
    ScheduleType,
)
from mlagents.trainers.torch_entities.networks import ValueNetwork
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.exception import TrainerConfigError


@attr.s(auto_attribs=True)
class A2CSettings(OnPolicyHyperparamSettings):
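    """
    Hyperparameter settings for the A2C trainer. These extend the on-policy
    hyperparameters; num_epoch is validated below to always be 1, since A2C
    performs a single gradient update per batch of experience.
    """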
    beta: float = 5.0e-3
    lambd: float = 0.95
    num_epoch: int = attr.ib(default=1)
    shared_critic: bool = False

    @num_epoch.validator
    def _check_num_epoch_one(self, attribute, value):
        if value != 1:
            raise TrainerConfigError("A2C requires num_epoch = 1")

    learning_rate_schedule: ScheduleType = ScheduleType.LINEAR
    beta_schedule: ScheduleType = ScheduleType.LINEAR


class A2COptimizer(TorchOptimizer):
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        """
        Takes a Policy and a TrainerSettings object and creates an Optimizer around the policy.
        The A2C optimizer has a value estimator and a loss function.
        :param policy: A TorchPolicy object that will be updated by this A2C Optimizer.
        :param trainer_settings: Trainer settings that specify the properties of the trainer.
        """
        super().__init__(policy, trainer_settings)
        self.hyperparameters: A2CSettings = cast(
            A2CSettings, trainer_settings.hyperparameters
        )
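
        # Gather the parameters to optimize: always the actor's parameters, plus the
        # separate value network's parameters when the critic is not shared with the actor.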
        params = list(self.policy.actor.parameters())
        if self.hyperparameters.shared_critic:
            self._critic = policy.actor
        else:
            self._critic = ValueNetwork(
                list(self.reward_signals.keys()),
                policy.behavior_spec.observation_specs,
                network_settings=trainer_settings.network_settings,
            )
            self._critic.to(default_device())
            params += list(self._critic.parameters())
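
        # Both the learning rate and the entropy coefficient (beta) are decayed over
        # max_steps according to their configured schedules.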
        self.decay_learning_rate = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.learning_rate,
            1e-10,
            self.trainer_settings.max_steps,
        )
        self.decay_beta = ModelUtils.DecayedValue(
            self.hyperparameters.beta_schedule,
            self.hyperparameters.beta,
            1e-10,
            self.trainer_settings.max_steps,
        )

        self.optimizer = torch.optim.Adam(
            params, lr=self.trainer_settings.hyperparameters.learning_rate
        )
        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
            "Losses/Policy Loss": "policy_loss",
        }

        self.stream_names = list(self.reward_signals.keys())

    @property
    def critic(self):
        return self._critic

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Performs update on model.
        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process.
        :return: Results of update.
        """
        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
        returns = {}
        for name in self.reward_signals:
            returns[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.returns_key(name)]
            )
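
        # Convert the buffered observations, action masks, and actions into tensors.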
        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)
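
        # For recurrent policies, take the stored memories at the start of each sequence
        # for the actor and, separately, for the critic; both lists are empty otherwise.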
        memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
            for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
        ]
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)

        value_memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(
                0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
            )
        ]
        if len(value_memories) > 0:
            value_memories = torch.stack(value_memories).unsqueeze(0)
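
        # Re-evaluate the current policy on the batch to get log-probabilities and
        # entropy for the stored actions, and run the critic to get value estimates
        # for each reward stream.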
        run_out = self.policy.actor.get_stats(
            current_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )
        log_probs = run_out["log_probs"]
        entropy = run_out["entropy"]

        values, _ = self.critic.critic_pass(
            current_obs,
            memories=value_memories,
            sequence_length=self.policy.sequence_length,
        )
        log_probs = log_probs.flatten()
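
        # Value loss: mean squared error between the stored returns and the critic's
        # estimates, averaged over all reward-signal value heads.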
        value_loss_per_head = []
        for name, head in values.items():
            returns_tensor = returns[name]
            squared_error = (returns_tensor - head) ** 2
            value_loss_per_head.append(squared_error)
        value_loss = torch.mean(torch.stack(value_loss_per_head))
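
        # Policy loss: negative log-probability of the taken actions (summed over action
        # dimensions), weighted by the advantages stored in the buffer.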
        advantages = ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES])
        policy_loss = -1 * torch.mean(torch.sum(log_probs, dim=1) * advantages)
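
        # Total loss: policy loss plus half the value loss, minus an entropy bonus
        # scaled by the decayed beta.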
        loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy)

        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        update_stats = {
            "Losses/Policy Loss": torch.abs(policy_loss).item(),
            "Losses/Value Loss": value_loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Beta": decay_bet,
        }

        return update_stats

    def get_modules(self):
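        """
        Returns the torch modules and optimizer tracked by this optimizer, including any
        modules owned by the reward providers, keyed by name.
        """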
        modules = {
            "Optimizer:value_optimizer": self.optimizer,
            "Optimizer:critic": self._critic,
        }
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules