from typing import cast
from mlagents.torch_utils import torch, nn, default_device
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil
from mlagents_envs.timers import timed
from typing import List, Dict, Tuple, Optional, Union, Any
from mlagents.trainers.torch_entities.networks import ValueNetwork, Actor
from mlagents_envs.base_env import ActionSpec, ObservationSpec
from mlagents.trainers.torch_entities.agent_action import AgentAction
from mlagents.trainers.torch_entities.utils import ModelUtils
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.settings import TrainerSettings, OffPolicyHyperparamSettings
from mlagents.trainers.settings import ScheduleType, NetworkSettings
from mlagents.trainers.torch_entities.networks import Critic
import numpy as np
import attr
# TODO: fix saving to onnx
@attr.s(auto_attribs=True)
class DQNSettings(OffPolicyHyperparamSettings):
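    """Hyperparameters for an epsilon-greedy DQN trainer: discount, exploration
    schedule, target-network update settings, and replay-buffer options."""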
gamma: float = 0.99
exploration_schedule: ScheduleType = ScheduleType.LINEAR
exploration_initial_eps: float = 0.1
exploration_final_eps: float = 0.05
target_update_interval: int = 10000
tau: float = 0.005
steps_per_update: float = 1
save_replay_buffer: bool = False
reward_signal_steps_per_update: float = attr.ib()
@reward_signal_steps_per_update.default
def _reward_signal_steps_per_update_default(self):
return self.steps_per_update
class DQNOptimizer(TorchOptimizer):
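    """Optimizer for a Q-network policy: builds TD targets from a target network
    and trains the online network with a smooth L1 (Huber) loss per reward stream."""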
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
super().__init__(policy, trainer_settings)
        # Optimize every parameter of the online Q-network (the policy's actor).
        params = list(self.policy.actor.parameters())
        self.optimizer = torch.optim.Adam(
            params, lr=self.trainer_settings.hyperparameters.learning_rate
        )
self.stream_names = list(self.reward_signals.keys())
self.gammas = [_val.gamma for _val in trainer_settings.reward_signals.values()]
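        # 1.0 if the done flag should zero out the bootstrapped next-state value
        # for this reward stream, 0.0 if the signal ignores episode termination.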
self.use_dones_in_backup = {
name: int(not self.reward_signals[name].ignore_done)
for name in self.stream_names
}
self.hyperparameters: DQNSettings = cast(
DQNSettings, trainer_settings.hyperparameters
)
self.tau = self.hyperparameters.tau
self.decay_learning_rate = ModelUtils.DecayedValue(
self.hyperparameters.learning_rate_schedule,
self.hyperparameters.learning_rate,
1e-10,
self.trainer_settings.max_steps,
)
self.decay_exploration_rate = ModelUtils.DecayedValue(
self.hyperparameters.exploration_schedule,
self.hyperparameters.exploration_initial_eps,
self.hyperparameters.exploration_final_eps,
            20000,  # exploration is annealed over a fixed 20k-step horizon, not max_steps
)
        # Initialize the target Q-network and hard-copy the online network's weights into it.
        self.q_net_target = QNetwork(
            stream_names=self.stream_names,
            observation_specs=policy.behavior_spec.observation_specs,
            network_settings=policy.network_settings,
            action_spec=policy.behavior_spec.action_spec,
        )
        ModelUtils.soft_update(self.policy.actor, self.q_net_target, 1.0)
        self.q_net_target.to(default_device())
@property
def critic(self):
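        """Expose the target Q-network as this optimizer's critic for value estimates."""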
return self.q_net_target
@timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"""
Performs update on model.
:param batch: Batch of experiences.
:param num_sequences: Number of sequences to process.
:return: Results of update.
"""
# Get decayed parameters
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
exp_rate = self.decay_exploration_rate.get_value(self.policy.get_current_step())
self.policy.actor.exploration_rate = exp_rate
rewards = {}
for name in self.reward_signals:
rewards[name] = ModelUtils.list_to_tensor(
batch[RewardSignalUtil.rewards_key(name)]
)
n_obs = len(self.policy.behavior_spec.observation_specs)
current_obs = ObsUtil.from_buffer(batch, n_obs)
# Convert to tensors
current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
next_obs = ObsUtil.from_buffer_next(batch, n_obs)
# Convert to tensors
next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
actions = AgentAction.from_buffer(batch)
dones = ModelUtils.list_to_tensor(batch[BufferKey.DONE])
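        # Per-reward-stream Q-values for the current observations from the online network.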
current_q_values, _ = self.policy.actor.critic_pass(
current_obs, sequence_length=self.policy.sequence_length
)
        qloss = []
        with torch.no_grad():
            # Double-DQN-style target: the online network selects the greedy
            # action for the next observations; the target network evaluates it.
            next_q_values_online, _ = self.policy.actor.critic_pass(
                next_obs, sequence_length=self.policy.sequence_length
            )
            greedy_actions = self.policy.actor.get_greedy_action(next_q_values_online)
            next_q_values_list, _ = self.q_net_target.critic_pass(
                next_obs, sequence_length=self.policy.sequence_length
            )
for name_i, name in enumerate(rewards.keys()):
with torch.no_grad():
next_q_values = torch.gather(
next_q_values_list[name], dim=1, index=greedy_actions
).squeeze()
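                # TD target: r + gamma * Q_target(s', a*), with the bootstrap
                # term masked out on episode termination when dones are used.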
target_q_values = rewards[name] + (
(1.0 - self.use_dones_in_backup[name] * dones)
* self.gammas[name_i]
* next_q_values
)
target_q_values = target_q_values.reshape(-1, 1)
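            # Q(s, a) from the online network for the actions actually taken in the batch.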
curr_q = torch.gather(
current_q_values[name], dim=1, index=actions.discrete_tensor
)
qloss.append(torch.nn.functional.smooth_l1_loss(curr_q, target_q_values))
loss = torch.mean(torch.stack(qloss))
ModelUtils.update_learning_rate(self.optimizer, decay_lr)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
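        # Polyak-average the online network's weights into the target network.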
ModelUtils.soft_update(self.policy.actor, self.q_net_target, self.tau)
update_stats = {
"Losses/Value Loss": loss.item(),
"Policy/Learning Rate": decay_lr,
"Policy/epsilon": exp_rate,
}
for reward_provider in self.reward_signals.values():
update_stats.update(reward_provider.update(batch))
return update_stats
def get_modules(self):
modules = {
"Optimizer:value_optimizer": self.optimizer,
"Optimizer:critic": self.critic,
}
for reward_provider in self.reward_signals.values():
modules.update(reward_provider.get_modules())
return modules
class QNetwork(nn.Module, Actor, Critic):
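    """Q-network that also acts as the Actor: its critic outputs are per-action
    Q-values, and actions are chosen epsilon-greedily over their sum."""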
MODEL_EXPORT_VERSION = 3
def __init__(
self,
stream_names: List[str],
observation_specs: List[ObservationSpec],
network_settings: NetworkSettings,
action_spec: ActionSpec,
exploration_initial_eps: float = 1.0,
):
        nn.Module.__init__(self)
        self.exploration_rate = exploration_initial_eps
        # One Q-value output per discrete action; only discrete action specs are supported.
        output_act_size = max(sum(action_spec.discrete_branches), 1)
self.network_body = ValueNetwork(
stream_names,
observation_specs,
network_settings,
outputs_per_stream=output_act_size,
)
# extra tensors for exporting to ONNX
self.action_spec = action_spec
self.version_number = torch.nn.Parameter(
torch.Tensor([self.MODEL_EXPORT_VERSION]), requires_grad=False
)
self.is_continuous_int_deprecated = torch.nn.Parameter(
torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False
)
self.continuous_act_size_vector = torch.nn.Parameter(
torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False
)
self.discrete_act_size_vector = torch.nn.Parameter(
torch.Tensor([self.action_spec.discrete_branches]), requires_grad=False
)
self.act_size_vector_deprecated = torch.nn.Parameter(
torch.Tensor(
[
self.action_spec.continuous_size
+ sum(self.action_spec.discrete_branches)
]
),
requires_grad=False,
)
self.memory_size_vector = torch.nn.Parameter(
torch.Tensor([int(self.network_body.memory_size)]), requires_grad=False
)
def update_normalization(self, buffer: AgentBuffer) -> None:
self.network_body.update_normalization(buffer)
def critic_pass(
self,
inputs: List[torch.Tensor],
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
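        """Return per-reward-stream Q-values of shape (batch, num_actions), plus memories."""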
value_outputs, critic_mem_out = self.network_body(
inputs, memories=memories, sequence_length=sequence_length
)
return value_outputs, critic_mem_out
@property
def memory_size(self) -> int:
return self.network_body.memory_size
def forward(
self,
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[Union[int, torch.Tensor], ...]:
out_vals, memories = self.critic_pass(inputs, memories, sequence_length)
        # ONNX export outputs. The deterministic action of a Q-network is the
        # greedy action; epsilon-greedy sampling is not exported.
        export_out = [self.version_number, self.memory_size_vector]
        disc_action_out = self.get_greedy_action(out_vals)
        deterministic_disc_action_out = disc_action_out
        export_out += [
            disc_action_out,
            self.discrete_act_size_vector,
            deterministic_disc_action_out,
        ]
return tuple(export_out)
    def get_random_action(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        # Uniform random action over the first (and only supported) discrete
        # branch, one action per agent in the batch.
        batch_size = inputs[0].shape[0]
        return torch.randint(
            0, self.action_spec.discrete_branches[0], (batch_size, 1)
        )
@staticmethod
def get_greedy_action(q_values) -> torch.Tensor:
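        # Sum Q-values across reward streams and pick the argmax action for each agent.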
all_q = torch.cat([val.unsqueeze(0) for val in q_values.values()])
return torch.argmax(all_q.sum(dim=0), dim=1, keepdim=True)
def get_action_and_stats(
self,
inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
deterministic=False,
) -> Tuple[AgentAction, Dict[str, Any], torch.Tensor]:
        run_out = {}
        if not deterministic and np.random.rand() < self.exploration_rate:
            # Epsilon-greedy exploration: sample a uniform random discrete action.
            discrete_action = self.get_random_action(inputs)
        else:
            # Otherwise act greedily with respect to the summed per-stream Q-values.
            out_vals, _ = self.critic_pass(inputs, memories, sequence_length)
            discrete_action = self.get_greedy_action(out_vals)
        action_out = AgentAction(None, [discrete_action])
        run_out["env_action"] = action_out.to_action_tuple()
        return action_out, run_out, torch.Tensor([])