from typing import Dict, cast
import attr
from mlagents.torch_utils import torch, default_device
from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil
from mlagents_envs.timers import timed
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.settings import (
TrainerSettings,
OnPolicyHyperparamSettings,
ScheduleType,
)
from mlagents.trainers.torch_entities.networks import ValueNetwork
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.exception import TrainerConfigError


@attr.s(auto_attribs=True)
class A2CSettings(OnPolicyHyperparamSettings):
    beta: float = 5.0e-3  # strength of the entropy regularization
    lambd: float = 0.95  # lambda parameter for generalized advantage estimation
    num_epoch: int = attr.ib(default=1)  # A2C does just one pass over each batch
    shared_critic: bool = False
    learning_rate_schedule: ScheduleType = ScheduleType.LINEAR
    beta_schedule: ScheduleType = ScheduleType.LINEAR

    @num_epoch.validator
    def _check_num_epoch_one(self, attribute, value):
        if value != 1:
            raise TrainerConfigError("A2C requires num_epoch = 1")


class A2COptimizer(TorchOptimizer):
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
"""
Takes a Policy and a Dict of trainer parameters and creates an Optimizer around the policy.
The A2C optimizer has a value estimator and a loss function.
:param policy: A TorchPolicy object that will be updated by this A2C Optimizer.
:param trainer_params: Trainer parameters dictionary that specifies the
properties of the trainer.
"""
# Create the graph here to give more granular control of the TF graph to the Optimizer.
super().__init__(policy, trainer_settings)
self.hyperparameters: A2CSettings = cast(
A2CSettings, trainer_settings.hyperparameters
)
params = list(self.policy.actor.parameters())
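        # Use the actor itself as the critic when the networks are shared; otherwise
        # build a separate ValueNetwork with one value stream per reward signal.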
if self.hyperparameters.shared_critic:
self._critic = policy.actor
else:
self._critic = ValueNetwork(
list(self.reward_signals.keys()),
policy.behavior_spec.observation_specs,
network_settings=trainer_settings.network_settings,
)
self._critic.to(default_device())
params += list(self._critic.parameters())
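        # The learning rate and the entropy coefficient (beta) decay from their initial
        # values towards ~0 over max_steps, following the configured schedules.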
self.decay_learning_rate = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.learning_rate,
1e-10,
self.trainer_settings.max_steps,
)
self.decay_beta = ModelUtils.DecayedValue(
self.hyperparameters.beta_schedule,
self.hyperparameters.beta,
1e-10,
self.trainer_settings.max_steps,
)
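        # A single Adam optimizer updates the actor and critic parameters together.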
self.optimizer = torch.optim.Adam(
params, lr=self.trainer_settings.hyperparameters.learning_rate
)
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
}
self.stream_names = list(self.reward_signals.keys())

    @property
    def critic(self):
        return self._critic

    @timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"""
        Performs a single A2C gradient update on the policy and the value heads.
        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process.
        :return: Output statistics of the update.
"""
# Get decayed parameters
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
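        # Discounted returns for each reward signal, computed upstream and stored in the buffer.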
returns = {}
for name in self.reward_signals:
returns[name] = ModelUtils.list_to_tensor(
batch[RewardSignalUtil.returns_key(name)]
)
n_obs = len(self.policy.behavior_spec.observation_specs)
current_obs = ObsUtil.from_buffer(batch, n_obs)
# Convert to tensors
current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
actions = AgentAction.from_buffer(batch)
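        # For recurrent policies, memories are stored once per sequence; collect and stack
        # them for the forward passes below. The list is empty for non-recurrent policies.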
memories = [
ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
]
if len(memories) > 0:
memories = torch.stack(memories).unsqueeze(0)
# Get value memories
value_memories = [
ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
for i in range(
0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
)
]
if len(value_memories) > 0:
value_memories = torch.stack(value_memories).unsqueeze(0)
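        # Evaluate log probabilities and entropy of the stored actions under the current policy.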
run_out = self.policy.actor.get_stats(
current_obs,
masks=act_masks,
actions=actions,
memories=memories,
sequence_length=self.policy.sequence_length,
)
log_probs = run_out["log_probs"]
entropy = run_out["entropy"]
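        # Value estimates from the critic, one tensor per reward stream.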
values, _ = self.critic.critic_pass(
current_obs,
memories=value_memories,
sequence_length=self.policy.sequence_length,
)
log_probs = log_probs.flatten()
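        # Value loss: mean squared error between the stored returns and the value
        # estimate of each reward stream, averaged across streams.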
        value_loss_per_head = []
        for name, head in values.items():
            returns_tensor = returns[name]
            squared_error = (returns_tensor - head) ** 2
            value_loss_per_head.append(squared_error)
value_loss = torch.mean(torch.stack(value_loss_per_head))
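        # Policy-gradient loss: summed log probability of each taken action, weighted by
        # its advantage estimate. Unlike PPO, A2C uses no clipped surrogate objective.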
advantages = ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES])
policy_loss = -1 * torch.mean(torch.sum(log_probs, dim=1) * advantages)
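        # Total loss combines the policy loss, the value loss (weighted by 0.5) and a
        # decayed entropy bonus that encourages exploration.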
loss = policy_loss + 0.5 * value_loss - decay_bet * torch.mean(entropy)
# Set optimizer learning rate
ModelUtils.update_learning_rate(self.optimizer, decay_lr)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
update_stats = {
# NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
# TODO: After PyTorch is default, change to something more correct.
"Losses/Policy Loss": torch.abs(policy_loss).item(),
"Losses/Value Loss": value_loss.item(),
"Policy/Learning Rate": decay_lr,
"Policy/Beta": decay_bet,
}
return update_stats

    def get_modules(self):
        # Modules to be saved and restored alongside the policy checkpoint.
modules = {
"Optimizer:value_optimizer": self.optimizer,
"Optimizer:critic": self._critic,
}
for reward_provider in self.reward_signals.values():
modules.update(reward_provider.get_modules())
return modules