from collections import defaultdict
from collections.abc import MutableMapping
import enum
import itertools
from typing import BinaryIO, DefaultDict, List, Tuple, Union, Optional

import numpy as np
import h5py

from mlagents_envs.exception import UnityException


BufferEntry = Union[np.ndarray, List[np.ndarray]]


class BufferException(UnityException):
    """
    Related to errors with the Buffer.
    """

    pass


class BufferKey(enum.Enum):
    ACTION_MASK = "action_mask"
    CONTINUOUS_ACTION = "continuous_action"
    NEXT_CONT_ACTION = "next_continuous_action"
    CONTINUOUS_LOG_PROBS = "continuous_log_probs"
    DISCRETE_ACTION = "discrete_action"
    NEXT_DISC_ACTION = "next_discrete_action"
    DISCRETE_LOG_PROBS = "discrete_log_probs"
    DONE = "done"
    ENVIRONMENT_REWARDS = "environment_rewards"
    MASKS = "masks"
    MEMORY = "memory"
    CRITIC_MEMORY = "critic_memory"
    BASELINE_MEMORY = "poca_baseline_memory"
    PREV_ACTION = "prev_action"

    ADVANTAGES = "advantages"
    DISCOUNTED_RETURNS = "discounted_returns"

    GROUP_DONES = "group_dones"
    GROUPMATE_REWARDS = "groupmate_reward"
    GROUP_REWARD = "group_reward"
    GROUP_CONTINUOUS_ACTION = "group_continuous_action"
    GROUP_DISCRETE_ACTION = "group_discrete_action"
    GROUP_NEXT_CONT_ACTION = "group_next_cont_action"
    GROUP_NEXT_DISC_ACTION = "group_next_disc_action"


class ObservationKeyPrefix(enum.Enum):
    OBSERVATION = "obs"
    NEXT_OBSERVATION = "next_obs"

    GROUP_OBSERVATION = "group_obs"
    NEXT_GROUP_OBSERVATION = "next_group_obs"


class RewardSignalKeyPrefix(enum.Enum):
    REWARDS = "rewards"
    VALUE_ESTIMATES = "value_estimates"
    RETURNS = "returns"
    ADVANTAGE = "advantage"
    BASELINES = "baselines"


AgentBufferKey = Union[
    BufferKey, Tuple[ObservationKeyPrefix, int], Tuple[RewardSignalKeyPrefix, str]
]


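# An AgentBufferKey is either a plain BufferKey, an (ObservationKeyPrefix, index)
# pair, or a (RewardSignalKeyPrefix, reward-signal name) pair. An illustrative
# sketch of the three forms (the index 0 and the name "extrinsic" are arbitrary):
#     k1: AgentBufferKey = BufferKey.DONE
#     k2: AgentBufferKey = (ObservationKeyPrefix.OBSERVATION, 0)
#     k3: AgentBufferKey = (RewardSignalKeyPrefix.REWARDS, "extrinsic")

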
class RewardSignalUtil:
    @staticmethod
    def rewards_key(name: str) -> AgentBufferKey:
        return RewardSignalKeyPrefix.REWARDS, name

    @staticmethod
    def value_estimates_key(name: str) -> AgentBufferKey:
        return RewardSignalKeyPrefix.VALUE_ESTIMATES, name

    @staticmethod
    def returns_key(name: str) -> AgentBufferKey:
        return RewardSignalKeyPrefix.RETURNS, name

    @staticmethod
    def advantage_key(name: str) -> AgentBufferKey:
        return RewardSignalKeyPrefix.ADVANTAGE, name

    @staticmethod
    def baseline_estimates_key(name: str) -> AgentBufferKey:
        return RewardSignalKeyPrefix.BASELINES, name


class AgentBufferField(list):
    """
    AgentBufferField is a list of numpy arrays, or List[np.ndarray] for group entries.
    When an agent collects a field, you can add it to its AgentBufferField with the append method.
    """

    def __init__(self, *args, **kwargs):
        self.padding_value = 0
        super().__init__(*args, **kwargs)

    def __str__(self) -> str:
        return f"AgentBufferField: {super().__str__()}"

    def __getitem__(self, index):
        return_data = super().__getitem__(index)
        if isinstance(return_data, list):
            return AgentBufferField(return_data)
        else:
            return return_data

    @property
    def contains_lists(self) -> bool:
        """
        Checks whether this AgentBufferField contains List[np.ndarray].
        """
        return len(self) > 0 and isinstance(self[0], list)

    def append(self, element: BufferEntry, padding_value: float = 0.0) -> None:
        """
        Adds an element to this list. Also lets you change the padding
        value, so that it can be set on append (e.g. action_masks should
        be padded with 1.)
        :param element: The element to append to the list.
        :param padding_value: The value used to pad when get_batch is called.
        """
        super().append(element)
        self.padding_value = padding_value

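    # Illustrative sketch (hypothetical values): action masks are appended with
    # padding_value=1.0 so that any padding added later by get_batch is a valid,
    # all-ones mask rather than all zeros:
    #     field = AgentBufferField()
    #     field.append(np.ones(4, dtype=np.float32), padding_value=1.0)
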
    def set(self, data: List[BufferEntry]) -> None:
        """
        Sets the list of BufferEntry to the input data.
        :param data: The BufferEntry list to be set.
        """
        self[:] = data

    def get_batch(
        self,
        batch_size: Optional[int] = None,
        training_length: Optional[int] = 1,
        sequential: bool = True,
    ) -> List[BufferEntry]:
        """
        Retrieve the last batch_size elements of length training_length
        from the list of np.array.
        :param batch_size: The number of elements to retrieve. If None:
        All elements will be retrieved.
        :param training_length: The length of the sequence to be retrieved. If
        None: only takes one element.
        :param sequential: If true and training_length is not None: the elements
        will not repeat in the sequence. [a,b,c,d,e] with training_length = 2 and
        sequential=True gives [[a,b],[c,d],[e,0]], where 0 is padding. If
        sequential=False gives [[a,b],[b,c],[c,d],[d,e]].
        """
        if training_length is None:
            training_length = 1
        if sequential:
            # The sequences do not overlap, so the number of elements kept is a
            # multiple of training_length; leftover elements are padded out.
            leftover = len(self) % training_length
            if batch_size is None:
                # Retrieve as many whole sequences as possible.
                batch_size = len(self) // training_length + 1 * (leftover != 0)
            if batch_size > (len(self) // training_length + 1 * (leftover != 0)):
                raise BufferException(
                    "The batch size and training length requested for get_batch were"
                    " too large given the current number of data points."
                )
            if batch_size * training_length > len(self):
                if self.contains_lists:
                    padding = []
                else:
                    # Duplicate the shape of the last element, multiplied by the
                    # padding_value.
                    padding = np.array(self[-1], dtype=np.float32) * self.padding_value
                return self[:] + [padding] * (training_length - leftover)
            else:
                return self[len(self) - batch_size * training_length :]
        else:
            # The sequences overlap, so one element can appear in up to
            # training_length different sequences.
            if batch_size is None:
                # Retrieve as many sequences as possible.
                batch_size = len(self) - training_length + 1
            if (len(self) - training_length + 1) < batch_size:
                raise BufferException(
                    "The batch size and training length requested for get_batch were"
                    " too large given the current number of data points."
                )
            tmp_list: List[np.ndarray] = []
            for end in range(len(self) - batch_size + 1, len(self) + 1):
                tmp_list += self[end - training_length : end]
            return tmp_list

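    # Illustrative sketch of the two modes (hypothetical values; each a_i is an
    # np.ndarray previously appended to this field, pad = a5 * padding_value):
    #     field = [a1, a2, a3, a4, a5]
    #     field.get_batch(batch_size=3, training_length=2, sequential=True)
    #         -> [a1, a2, a3, a4, a5, pad]   # 3 non-overlapping sequences, flattened
    #     field.get_batch(batch_size=2, training_length=2, sequential=False)
    #         -> [a3, a4, a4, a5]            # overlapping sequences, flattened
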
    def reset_field(self) -> None:
        """
        Resets the AgentBufferField.
        """
        self[:] = []

    def padded_to_batch(
        self, pad_value: float = 0, dtype: np.dtype = np.float32
    ) -> Union[np.ndarray, List[np.ndarray]]:
        """
        Converts this AgentBufferField (which is a List[BufferEntry]) into a numpy array
        with first dimension equal to the length of this AgentBufferField. If this AgentBufferField
        contains List[np.ndarray] entries (i.e., in the case of group observations), return a List
        of numpy arrays of length equal to the maximum length of an entry. Entries shorter than
        that length are padded with pad_value.
        :param pad_value: Value to pad List AgentBufferFields, when there are fewer than the maximum
        number of agents present.
        :param dtype: Dtype of output numpy array.
        :return: Numpy array or List of numpy arrays representing this AgentBufferField, where the first
        dimension is equal to the length of the AgentBufferField.
        """
        if len(self) > 0 and not isinstance(self[0], list):
            return np.asanyarray(self, dtype=dtype)

        shape = None
        for _entry in self:
            # _entry could be an empty list if there were no group agents at that
            # step. Find the first non-empty entry and use its shape.
            if _entry:
                shape = _entry[0].shape
                break
        # If there were no groupmate agents in the entire batch, return an empty list.
        if shape is None:
            return []

        # Pad every entry to the maximum entry length in the batch.
        new_list = list(
            map(
                lambda x: np.asanyarray(x, dtype=dtype),
                itertools.zip_longest(*self, fillvalue=np.full(shape, pad_value)),
            )
        )
        return new_list

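    # Illustrative sketch for the group case (hypothetical shapes): a field of
    # two steps, with two groupmates on the first step and one on the second,
    #     field = [[x1, x2], [y1]]          # each x_i, y_i of shape (3,)
    # becomes a list of two arrays, each of shape (2, 3):
    #     [np.array([x1, y1]), np.array([x2, <pad>])]
    # where <pad> is np.full((3,), pad_value).
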
    def to_ndarray(self):
        """
        Returns this AgentBufferField, which is a list of np.ndarray (or
        List[np.ndarray]), as a single np.ndarray.
        """
        return np.array(self)


class AgentBuffer(MutableMapping):
    """
    AgentBuffer contains a dictionary of AgentBufferFields. Each agent has its own AgentBuffer.
    The keys correspond to the name of the field. Example: state, action.
    """

    # Whether to validate key types at runtime; off by default for speed.
    CHECK_KEY_TYPES_AT_RUNTIME = False

    def __init__(self):
        self.last_brain_info = None
        self.last_take_action_outputs = None
        self._fields: DefaultDict[AgentBufferKey, AgentBufferField] = defaultdict(
            AgentBufferField
        )

    def __str__(self):
        return ", ".join([f"'{k}' : {str(self[k])}" for k in self._fields.keys()])

    def reset_agent(self) -> None:
        """
        Resets the AgentBuffer.
        """
        for f in self._fields.values():
            f.reset_field()
        self.last_brain_info = None
        self.last_take_action_outputs = None

    @staticmethod
    def _check_key(key):
        if isinstance(key, BufferKey):
            return
        if isinstance(key, tuple):
            key0, key1 = key
            if isinstance(key0, ObservationKeyPrefix):
                if isinstance(key1, int):
                    return
                raise KeyError(f"{key} has type ({type(key0)}, {type(key1)})")
            if isinstance(key0, RewardSignalKeyPrefix):
                if isinstance(key1, str):
                    return
                raise KeyError(f"{key} has type ({type(key0)}, {type(key1)})")
        raise KeyError(f"{key} is a {type(key)}")

    @staticmethod
    def _encode_key(key: AgentBufferKey) -> str:
        """
        Convert the key to a string representation so that it can be used for serialization.
        """
        if isinstance(key, BufferKey):
            return key.value
        prefix, suffix = key
        return f"{prefix.value}:{suffix}"

    @staticmethod
    def _decode_key(encoded_key: str) -> AgentBufferKey:
        """
        Convert the string representation back to a key after serialization.
        """
        # Simple case: the string directly matches a BufferKey.
        try:
            return BufferKey(encoded_key)
        except ValueError:
            pass

        # Not a simple key, so split into a prefix and a suffix.
        prefix_str, _, suffix_str = encoded_key.partition(":")

        # See if it's an ObservationKeyPrefix (with an integer suffix).
        try:
            return ObservationKeyPrefix(prefix_str), int(suffix_str)
        except ValueError:
            pass

        # If not, it had better be a RewardSignalKeyPrefix (with a string suffix).
        try:
            return RewardSignalKeyPrefix(prefix_str), suffix_str
        except ValueError:
            raise ValueError(f"Unable to convert {encoded_key} to an AgentBufferKey")

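    # Illustrative round-trip sketch (the name "extrinsic" is arbitrary):
    #     AgentBuffer._encode_key((RewardSignalKeyPrefix.REWARDS, "extrinsic"))
    #         -> "rewards:extrinsic"
    #     AgentBuffer._decode_key("rewards:extrinsic")
    #         -> (RewardSignalKeyPrefix.REWARDS, "extrinsic")
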
    def __getitem__(self, key: AgentBufferKey) -> AgentBufferField:
        if self.CHECK_KEY_TYPES_AT_RUNTIME:
            self._check_key(key)
        return self._fields[key]

    def __setitem__(self, key: AgentBufferKey, value: AgentBufferField) -> None:
        if self.CHECK_KEY_TYPES_AT_RUNTIME:
            self._check_key(key)
        self._fields[key] = value

    def __delitem__(self, key: AgentBufferKey) -> None:
        if self.CHECK_KEY_TYPES_AT_RUNTIME:
            self._check_key(key)
        self._fields.__delitem__(key)

    def __iter__(self):
        return self._fields.__iter__()

    def __len__(self) -> int:
        return self._fields.__len__()

    def __contains__(self, key):
        if self.CHECK_KEY_TYPES_AT_RUNTIME:
            self._check_key(key)
        return self._fields.__contains__(key)

    def check_length(self, key_list: List[AgentBufferKey]) -> bool:
        """
        Some methods will require that some fields have the same length.
        check_length will return true if the fields in key_list
        have the same length.
        :param key_list: The fields whose lengths will be compared.
        """
        if self.CHECK_KEY_TYPES_AT_RUNTIME:
            for k in key_list:
                self._check_key(k)

        if len(key_list) < 2:
            return True
        length = None
        for key in key_list:
            if key not in self._fields:
                return False
            if (length is not None) and (length != len(self[key])):
                return False
            length = len(self[key])
        return True

    def shuffle(
        self, sequence_length: int, key_list: Optional[List[AgentBufferKey]] = None
    ) -> None:
        """
        Shuffles the fields in key_list in a consistent way: the reordering will
        be the same across fields.
        :param sequence_length: The length of the sequences that must stay together.
        :param key_list: The fields that must be shuffled. If None, all fields are shuffled.
        """
        if key_list is None:
            key_list = list(self._fields.keys())
        if not self.check_length(key_list):
            raise BufferException(
                "Unable to shuffle if the fields are not of same length"
            )
        s = np.arange(len(self[key_list[0]]) // sequence_length)
        np.random.shuffle(s)
        for key in key_list:
            buffer_field = self[key]
            tmp: List[np.ndarray] = []
            for i in s:
                tmp += buffer_field[i * sequence_length : (i + 1) * sequence_length]
            buffer_field.set(tmp)

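    # Illustrative sketch (hypothetical values): with 6 experiences and
    # sequence_length=2, the permutation is drawn over the 3 sequences,
    # e.g. s = [2, 0, 1], so every field in key_list is reordered as
    #     [e4, e5, e0, e1, e2, e3]
    # keeping each sequence intact and the ordering identical across fields.
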
    def make_mini_batch(self, start: int, end: int) -> "AgentBuffer":
        """
        Creates a mini-batch from the buffer.
        :param start: Starting index of the buffer.
        :param end: Ending index of the buffer.
        :return: AgentBuffer of the mini batch.
        """
        mini_batch = AgentBuffer()
        for key, field in self._fields.items():
            # Slicing an AgentBufferField returns an AgentBufferField.
            mini_batch[key] = field[start:end]
        return mini_batch

    def sample_mini_batch(
        self, batch_size: int, sequence_length: int = 1
    ) -> "AgentBuffer":
        """
        Creates a mini-batch from a random start and end.
        :param batch_size: Number of elements to withdraw.
        :param sequence_length: Length of sequences to sample.
        Number of sequences to sample will be batch_size/sequence_length.
        """
        num_seq_to_sample = batch_size // sequence_length
        mini_batch = AgentBuffer()
        buff_len = self.num_experiences
        num_sequences_in_buffer = buff_len // sequence_length
        # Sample random sequence starts, aligned to sequence_length.
        start_idxes = (
            np.random.randint(num_sequences_in_buffer, size=num_seq_to_sample)
            * sequence_length
        )
        for key in self:
            buffer_field = self[key]
            mb_list = (buffer_field[i : i + sequence_length] for i in start_idxes)
            # Flatten the sampled sequence slices into a single list.
            mini_batch[key].set(list(itertools.chain.from_iterable(mb_list)))
        return mini_batch

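    # Illustrative sketch (hypothetical values): batch_size=4 with
    # sequence_length=2 draws 2 random sequence starts (multiples of 2),
    # e.g. start_idxes = [4, 0], so each field of the mini-batch becomes
    #     [e4, e5, e0, e1]
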
    def save_to_file(self, file_object: BinaryIO) -> None:
        """
        Saves the AgentBuffer to a file-like object.
        """
        with h5py.File(file_object, "w") as write_file:
            for key, data in self.items():
                write_file.create_dataset(
                    self._encode_key(key), data=data, dtype="f", compression="gzip"
                )

    def load_from_file(self, file_object: BinaryIO) -> None:
        """
        Loads the AgentBuffer from a file-like object.
        """
        with h5py.File(file_object, "r") as read_file:
            for key in list(read_file.keys()):
                decoded_key = self._decode_key(key)
                self[decoded_key] = AgentBufferField()
                # extend() converts the dataset's first dimension into list elements.
                self[decoded_key].extend(read_file[key][()])

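    # Illustrative round-trip sketch (assumes h5py accepts file-like objects,
    # as recent versions do; values are hypothetical):
    #     f = io.BytesIO()
    #     buf.save_to_file(f)
    #     f.seek(0)
    #     restored = AgentBuffer()
    #     restored.load_from_file(f)
    #     assert restored.num_experiences == buf.num_experiences
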
    def truncate(self, max_length: int, sequence_length: int = 1) -> None:
        """
        Truncates the buffer to a certain length.

        This can be slow for large buffers. We compensate by cutting further than we need to, so that
        we're not truncating at each update. Note that we must truncate an integer number of sequence_lengths.
        :param max_length: The length at which to truncate the buffer.
        :param sequence_length: The length of the sequences that must be kept intact.
        """
        current_length = self.num_experiences
        # Make max_length an integer multiple of sequence_length.
        max_length -= max_length % sequence_length
        if current_length > max_length:
            for _key in self.keys():
                self[_key][:] = self[_key][current_length - max_length :]

    def resequence_and_append(
        self,
        target_buffer: "AgentBuffer",
        key_list: Optional[List[AgentBufferKey]] = None,
        batch_size: Optional[int] = None,
        training_length: Optional[int] = None,
    ) -> None:
        """
        Takes in a batch size and training length (sequence length), and appends this AgentBuffer to target_buffer
        properly padded for LSTM use. Optionally, use key_list to restrict which fields are inserted into the new
        buffer.
        :param target_buffer: The buffer to which the samples will be appended.
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        """
        if key_list is None:
            key_list = list(self.keys())
        if not self.check_length(key_list):
            raise BufferException(
                f"The fields {key_list} were not of the same length"
            )
        for field_key in key_list:
            target_buffer[field_key].extend(
                self[field_key].get_batch(
                    batch_size=batch_size, training_length=training_length
                )
            )

    @property
    def num_experiences(self) -> int:
        """
        The number of agent experiences in the AgentBuffer, i.e. the length of the buffer.

        An experience consists of one element across all of the fields of this AgentBuffer.
        Note that these all have to be the same length, otherwise shuffle and append_to_update_buffer
        will fail.
        """
        if self.values():
            return len(next(iter(self.values())))
        else:
            return 0