from typing import Dict, Optional, Tuple, List
from mlagents.torch_utils import torch
import numpy as np
from collections import defaultdict

from mlagents.trainers.buffer import AgentBuffer, AgentBufferField
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch_entities.components.bc.module import BCModule
from mlagents.trainers.torch_entities.components.reward_providers import (
    create_reward_provider,
)

from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.settings import (
    TrainerSettings,
    RewardSignalSettings,
    RewardSignalType,
)
from mlagents.trainers.torch_entities.utils import ModelUtils


class TorchOptimizer(Optimizer):
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        super().__init__()
        self.policy = policy
        self.trainer_settings = trainer_settings
        self.update_dict: Dict[str, torch.Tensor] = {}
        self.value_heads: Dict[str, torch.Tensor] = {}
        self.memory_in: torch.Tensor = None
        self.memory_out: torch.Tensor = None
        self.m_size: int = 0
        self.global_step = torch.tensor(0)
        self.bc_module: Optional[BCModule] = None
        self.create_reward_signals(trainer_settings.reward_signals)
        self.critic_memory_dict: Dict[str, torch.Tensor] = {}
        if trainer_settings.behavioral_cloning is not None:
            self.bc_module = BCModule(
                self.policy,
                trainer_settings.behavioral_cloning,
                policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
                default_batch_size=trainer_settings.hyperparameters.batch_size,
                default_num_epoch=3,
            )

    @property
    def critic(self):
        raise NotImplementedError
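
    # update() is a stub here; concrete optimizer subclasses override it with
    # their algorithm-specific update logic.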
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        pass

    def create_reward_signals(
        self, reward_signal_configs: Dict[RewardSignalType, RewardSignalSettings]
    ) -> None:
        """
        Create reward signals
        :param reward_signal_configs: Reward signal config.
        """
        for reward_signal, settings in reward_signal_configs.items():
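            # Key each provider by the reward signal's string name so it can be
            # looked up consistently during updates and stat reporting.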
            self.reward_signals[reward_signal.value] = create_reward_provider(
                reward_signal, self.policy.behavior_spec, settings
            )

    def _evaluate_by_sequence(
        self, tensor_obs: List[torch.Tensor], initial_memory: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
        """
        Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
        intermediate memories for the critic.
        :param tensor_obs: A List of tensors of shape (trajectory_len, <obs_dim>) that are the agent's
            observations for this trajectory.
        :param initial_memory: The memory that precedes this trajectory. Of shape (1,1,<mem_size>), i.e.
            what is returned as the output of a MemoryModule.
        :return: A Tuple of the value estimates as a Dict of [name, tensor], an AgentBufferField of the initial
            memories to be used during value function update, and the final memory at the end of the trajectory.
        """
        num_experiences = tensor_obs[0].shape[0]
        all_next_memories = AgentBufferField()
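        # The final sequence may be shorter than sequence_length if the trajectory
        # length is not an exact multiple; record how many experiences it contains
        # so it can be evaluated separately below.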
        leftover_seq_len = num_experiences % self.policy.sequence_length

        all_values: Dict[str, List[np.ndarray]] = defaultdict(list)
        _mem = initial_memory
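        # Evaluate each full sequence in order, carrying the critic memory (_mem)
        # over between sequences and recording the memory at every step so that it
        # can be replayed during the value-function update.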
        for seq_num in range(num_experiences // self.policy.sequence_length):
            seq_obs = []
            for _ in range(self.policy.sequence_length):
                all_next_memories.append(ModelUtils.to_numpy(_mem.squeeze()))
            start = seq_num * self.policy.sequence_length
            end = (seq_num + 1) * self.policy.sequence_length

            for _obs in tensor_obs:
                seq_obs.append(_obs[start:end])
            values, _mem = self.critic.critic_pass(
                seq_obs, _mem, sequence_length=self.policy.sequence_length
            )
            for signal_name, _val in values.items():
                all_values[signal_name].append(_val)
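
        # Evaluate the leftover experiences at the end of the trajectory: the final
        # sequence that is shorter than sequence_length, if there is one.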
        seq_obs = []

        if leftover_seq_len > 0:
            for _obs in tensor_obs:
                last_seq_obs = _obs[-leftover_seq_len:]
                seq_obs.append(last_seq_obs)

            for _ in range(leftover_seq_len):
                all_next_memories.append(ModelUtils.to_numpy(_mem.squeeze()))

            last_values, _mem = self.critic.critic_pass(
                seq_obs, _mem, sequence_length=leftover_seq_len
            )
            for signal_name, _val in last_values.items():
                all_values[signal_name].append(_val)
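
        # Concatenate the per-sequence estimates into one tensor per reward signal.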
        all_value_tensors = {
            signal_name: torch.cat(value_list, dim=0)
            for signal_name, value_list in all_values.items()
        }
        next_mem = _mem
        return all_value_tensors, all_next_memories, next_mem

    def update_reward_signals(self, batch: AgentBuffer) -> Dict[str, float]:
        update_stats: Dict[str, float] = {}
        for reward_provider in self.reward_signals.values():
            update_stats.update(reward_provider.update(batch))
        return update_stats

    def get_trajectory_value_estimates(
        self,
        batch: AgentBuffer,
        next_obs: List[np.ndarray],
        done: bool,
        agent_id: str = "",
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:
        """
        Get value estimates and memories for a trajectory, in batch form.
        :param batch: An AgentBuffer that consists of a trajectory.
        :param next_obs: the next observation (after the trajectory). Used for bootstrapping
            if this is not a terminal trajectory.
        :param done: Set true if this is a terminal trajectory.
        :param agent_id: Agent ID of the agent that this trajectory belongs to.
        :returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
            the final value estimate as a Dict of [name, float], and optionally (if using memories)
            an AgentBufferField of initial critic memories to be used during update.
        """
        n_obs = len(self.policy.behavior_spec.observation_specs)
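
        # Reuse the critic memory carried over from this agent's previous trajectory,
        # or start from a zeroed memory (None if the policy is not recurrent).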
        if agent_id in self.critic_memory_dict:
            memory = self.critic_memory_dict[agent_id]
        else:
            memory = (
                torch.zeros((1, 1, self.critic.memory_size))
                if self.policy.use_recurrent
                else None
            )

        current_obs = [
            ModelUtils.list_to_tensor(obs) for obs in ObsUtil.from_buffer(batch, n_obs)
        ]
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]
        next_obs = [obs.unsqueeze(0) for obs in next_obs]

        all_next_memories: Optional[AgentBufferField] = None
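
        # With a recurrent critic, evaluate sequence-by-sequence so intermediate
        # memories are captured; otherwise evaluate the whole trajectory in one batch.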
        with torch.no_grad():
            if self.policy.use_recurrent:
                (
                    value_estimates,
                    all_next_memories,
                    next_memory,
                ) = self._evaluate_by_sequence(current_obs, memory)
            else:
                value_estimates, next_memory = self.critic.critic_pass(
                    current_obs, memory, sequence_length=batch.num_experiences
                )
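
        # Save the final memory for this agent's next trajectory, then bootstrap a
        # value estimate from the observation that follows the trajectory.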
        self.critic_memory_dict[agent_id] = next_memory

        next_value_estimate, _ = self.critic.critic_pass(
            next_obs, next_memory, sequence_length=1
        )

        for name, estimate in value_estimates.items():
            value_estimates[name] = ModelUtils.to_numpy(estimate)
            next_value_estimate[name] = ModelUtils.to_numpy(next_value_estimate[name])
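
        # For terminal trajectories, zero the bootstrap value for reward signals that
        # respect episode termination and discard the stored critic memory.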
        if done:
            for k in next_value_estimate:
                if not self.reward_signals[k].ignore_done:
                    next_value_estimate[k] = 0.0
            if agent_id in self.critic_memory_dict:
                self.critic_memory_dict.pop(agent_id)
        return value_estimates, next_value_estimate, all_next_memories