# zjowowen's picture
# init space
# 079c32c
from collections.abc import Sequence, Mapping
from typing import List, Dict, Union, Any
import torch
import treetensor.torch as ttorch
import re
import collections.abc as container_abcs
from ding.compatibility import torch_ge_131
# Integer type(s) checked by the collate functions below when handling plain Python numbers.
# NOTE(review): presumably a Python 2 compatibility leftover (int/long) -- confirm before removing.
int_classes = int
# Types treated as strings; ``default_collate`` returns a batch of these unchanged.
string_classes = (str, bytes)
# Matches numpy dtype strings containing one of the characters S, a, U or O.
# ``default_collate`` raises a TypeError for ndarrays whose ``dtype.str`` matches
# (string and object arrays, per the comment at the use site).
np_str_obj_array_pattern = re.compile(r'[SaUO]')
# Message template for the TypeError raised by ``default_collate`` on unsupported input;
# the single ``{}`` placeholder receives the offending type or dtype.
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)
def ttorch_collate(x, json: bool = False, cat_1dim: bool = True):
    """
    Overview:
        Stack a list of (tree) tensors with ``ttorch.stack`` and optionally post-process the result: \
        squeeze stacked leaves of shape (B, 1) down to (B), and/or call ``.json()`` on the stacked object.
    Arguments:
        - x : The items to stack; passed directly to ``ttorch.stack``.
        - json (:obj:`bool`): If True, return ``stacked.json()`` instead of the stacked object itself \
            (presumably a plain nested dict -- confirm against treetensor). Defaults to False.
        - cat_1dim (:obj:`bool`): If True, every ``torch.Tensor`` leaf whose stacked shape is (B, 1) is \
            squeezed along its last dimension to shape (B). Defaults to True.
    Returns:
        The stacked result, after the optional squeeze and ``.json()`` conversion.
    Examples:
        >>> # NOTE(review): examples carried over from the previous documentation; not re-verified here.
        >>> # With ``cat_1dim=True`` the stacked result must provide ``.keys()``.
        >>> # case 1: Collate a list of tensors
        >>> tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
        >>> collated = ttorch_collate(tensors)
        collated = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        >>> # case 2: Collate a nested dictionary of tensors
        >>> nested_dict = {
                'a': torch.tensor([1, 2, 3]),
                'b': torch.tensor([4, 5, 6]),
                'c': torch.tensor([7, 8, 9])
            }
        >>> collated = ttorch_collate(nested_dict)
        collated = {
            'a': torch.tensor([1, 2, 3]),
            'b': torch.tensor([4, 5, 6]),
            'c': torch.tensor([7, 8, 9])
        }
        >>> # case 3: Collate a list of nested dictionaries of tensors
        >>> nested_dicts = [
                {'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6])},
                {'a': torch.tensor([7, 8, 9]), 'b': torch.tensor([10, 11, 12])}
            ]
        >>> collated = ttorch_collate(nested_dicts)
        collated = {
            'a': torch.tensor([[1, 2, 3], [7, 8, 9]]),
            'b': torch.tensor([[4, 5, 6], [10, 11, 12]])
        }
    """

    def _squeeze_unit_columns(tree):
        # Walk the tree in place: recurse into sub-trees, squeeze (B, 1) tensor leaves to (B).
        for name in tree.keys():
            if not isinstance(tree[name], torch.Tensor):
                _squeeze_unit_columns(tree[name])
            elif len(tree[name].shape) == 2 and tree[name].shape[1] == 1:
                tree[name] = tree[name].squeeze(-1)

    stacked = ttorch.stack(x)
    if cat_1dim:
        _squeeze_unit_columns(stacked)
    if not json:
        return stacked
    return stacked.json()
def default_collate(batch: Sequence,
                    cat_1dim: bool = True,
                    ignore_prefix: list = ['collate_ignore']) -> Union[torch.Tensor, Mapping, Sequence]:
    """
    Overview:
        Put each data field into a tensor with outer dimension batch size.
    Arguments:
        - batch (:obj:`Sequence`): A data sequence, whose length is batch size, whose element is one piece of data.
        - cat_1dim (:obj:`bool`): Whether tensors of shape (1, ) are concatenated into a (B, ) tensor instead of \
            being stacked into a (B, 1) tensor, defaults to True.
        - ignore_prefix (:obj:`list`): A list of key prefixes. Dictionary entries whose (string) key starts with \
            one of them are not collated; their values are gathered into a plain list instead. The same prefixes \
            are applied to nested dictionaries. Defaults to ['collate_ignore']. The list is only read, never \
            modified.
    Returns:
        - ret (:obj:`Union[torch.Tensor, Mapping, Sequence]`): the collated data, with batch size into each data \
            field. The return dtype depends on the original element dtype, can be [torch.Tensor, Mapping, Sequence].
    Raises:
        - TypeError: If the elements are of an unsupported type, or are numpy arrays of a string/object dtype.
    Example:
        >>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
        >>> a = [torch.zeros(2,3) for _ in range(4)]
        >>> default_collate(a).shape
        torch.Size([4, 2, 3])
        >>>
        >>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
        >>> a = [[0 for __ in range(3)] for _ in range(4)]
        >>> default_collate(a)
        [tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
        >>>
        >>> # a list with B dicts, whose values are tensors shaped :math:`(m, n)` -->>
        >>> # a dict whose values are tensors with shape :math:`(B, m, n)`
        >>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
        >>> print(a[0][2].shape, a[0][3].shape)
        torch.Size([2, 3]) torch.Size([3, 4])
        >>> b = default_collate(a)
        >>> print(b[2].shape, b[3].shape)
        torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
    """
    if isinstance(batch, ttorch.Tensor):
        return batch.json()

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch_ge_131() and torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, directly concatenate into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        if elem.shape == (1, ) and cat_1dim:
            # B tensors of shape (1, ) -> one tensor of shape (B, )
            return torch.cat(batch, 0, out=out)
        else:
            return torch.stack(batch, 0, out=out)
    elif isinstance(elem, ttorch.Tensor):
        return ttorch_collate(batch, json=True, cat_1dim=cat_1dim)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate(
                [torch.as_tensor(b) for b in batch], cat_1dim=cat_1dim, ignore_prefix=ignore_prefix
            )
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float32)
    elif isinstance(elem, int_classes):
        dtype = torch.bool if isinstance(elem, bool) else torch.int64
        return torch.tensor(batch, dtype=dtype)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        # ``str.startswith`` accepts a tuple of candidate prefixes.
        prefixes = tuple(ignore_prefix)
        ret = {}
        for key in elem:
            # Only string keys can carry an ignore prefix; other keys (e.g. ints) are always collated.
            if isinstance(key, str) and key.startswith(prefixes):
                ret[key] = [d[key] for d in batch]
            else:
                ret[key] = default_collate(
                    [d[key] for d in batch], cat_1dim=cat_1dim, ignore_prefix=ignore_prefix
                )
        return ret
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(
            *(
                default_collate(samples, cat_1dim=cat_1dim, ignore_prefix=ignore_prefix)
                for samples in zip(*batch)
            )
        )
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [
            default_collate(samples, cat_1dim=cat_1dim, ignore_prefix=ignore_prefix) for samples in transposed
        ]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
def timestep_collate(batch: List[Dict[str, Any]]) -> Dict[str, Union[torch.Tensor, list]]:
    """
    Overview:
        Collates a batch of timestepped data fields into tensors with the outer dimension being the batch size. \
        Each timestepped data field is represented as a tensor with shape [T, B, any_dims], where T is the length \
        of the sequence, B is the batch size, and any_dims represents the shape of the tensor at each timestep.
    Arguments:
        - batch(:obj:`List[Dict[str, Any]]`): A list with length B. Each element is either a dictionary that maps \
            a field name to a length-T sequence of values and additionally holds a ``prev_state`` entry, or a \
            list of T per-timestep items (tree tensor pipeline). For dictionary input, ``prev_state`` is \
            temporarily removed from every dictionary during collation and put back before returning, even if \
            collation fails.
    Returns:
        - ret(:obj:`Dict[str, Union[torch.Tensor, list]]`): The collated data, with the timestep and batch size \
            incorporated into each data field. The shape of each data field is [T, B, dim1, dim2, ...]. \
            ``prev_state`` is not collated into a tensor; it is returned as a list with T entries, each holding \
            the B per-sample states of that timestep.
    Examples:
        >>> batch = [
        ...     {'data': [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])], 'prev_state': [None, None]},
        ...     {'data': [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])], 'prev_state': [None, None]},
        ... ]
        >>> collated_data = timestep_collate(batch)
        >>> print(collated_data['data'].shape)
        torch.Size([2, 2, 3])
    """

    def stack(data):
        # Recursively turn every non-empty sequence of tensors into a single stacked tensor.
        if isinstance(data, container_abcs.Mapping):
            return {k: stack(data[k]) for k in data}
        elif isinstance(data, container_abcs.Sequence) and len(data) > 0 and isinstance(data[0], torch.Tensor):
            return torch.stack(data)
        else:
            # Includes empty sequences, which have no first element to inspect.
            return data

    elem = batch[0]
    assert isinstance(elem, (container_abcs.Mapping, list)), type(elem)
    if isinstance(elem, list):  # new pipeline + treetensor
        # prev_state[t][b] is the state of sample b at timestep t.
        prev_state = [[b[i].get('prev_state') for b in batch] for i in range(len(elem))]
        batch_data = ttorch.stack([ttorch_collate(b) for b in batch])  # (B, T, *)
        del batch_data.prev_state
        batch_data = batch_data.transpose(1, 0)
        batch_data.prev_state = prev_state
    else:
        # prev_state must not go through default_collate, so it is taken out of the input dicts first.
        prev_state = []
        try:
            for b in batch:
                prev_state.append(b.pop('prev_state'))
            batch_data = default_collate(batch)  # -> {some_key: T lists}, each list is [B, some_dim]
            batch_data = stack(batch_data)  # -> {some_key: [T, B, some_dim]}
            batch_data['prev_state'] = list(zip(*prev_state))
        finally:
            # append back prev_state, avoiding multi batch share the same data bug;
            # done in ``finally`` so a failed collation does not leave the caller's data modified.
            for b, state in zip(batch, prev_state):
                b['prev_state'] = state
    return batch_data
def diff_shape_collate(batch: Sequence) -> Union[torch.Tensor, Mapping, Sequence]:
    """
    Overview:
        Collates a batch of data with different shapes.
        This function is similar to `default_collate`, but it allows tensors in the batch to have `None` values, \
        which is common in StarCraft observations. Fields that cannot be merged into one tensor (a ``None`` \
        entry, tensors of differing shapes, strings) are returned as the uncollated sequence.
    Arguments:
        - batch (:obj:`Sequence`): A sequence of data, where each element is a piece of data.
    Returns:
        - ret (:obj:`Union[torch.Tensor, Mapping, Sequence]`): The collated data, with the batch size applied \
            to each data field. The return type depends on the original element type and can be a torch.Tensor, \
            Mapping, or Sequence.
    Raises:
        - TypeError: If the elements are of an unsupported type.
    Examples:
        >>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
        >>> a = [torch.zeros(2,3) for _ in range(4)]
        >>> diff_shape_collate(a).shape
        torch.Size([4, 2, 3])
        >>>
        >>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
        >>> a = [[0 for __ in range(3)] for _ in range(4)]
        >>> diff_shape_collate(a)
        [tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
        >>>
        >>> # a list with B dicts, whose values are tensors shaped :math:`(m, n)` -->>
        >>> # a dict whose values are tensors with shape :math:`(B, m, n)`
        >>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
        >>> print(a[0][2].shape, a[0][3].shape)
        torch.Size([2, 3]) torch.Size([3, 4])
        >>> b = diff_shape_collate(a)
        >>> print(b[2].shape, b[3].shape)
        torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
    """
    elem = batch[0]
    elem_type = type(elem)
    if any(item is None for item in batch):
        # At least one sample has no value for this field: leave the field uncollated.
        return batch
    elif isinstance(elem, torch.Tensor):
        if len({e.shape for e in batch}) != 1:
            # Shapes differ, so the tensors cannot be stacked.
            return batch
        return torch.stack(batch, 0)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            return diff_shape_collate([torch.as_tensor(b) for b in batch])  # todo
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float32)
    elif isinstance(elem, int):
        dtype = torch.bool if isinstance(elem, bool) else torch.int64
        return torch.tensor(batch, dtype=dtype)
    elif isinstance(elem, (str, bytes)):
        # Strings are sequences of strings; without this branch the generic sequence
        # handling below would recurse forever. Return them as-is, like default_collate.
        return batch
    elif isinstance(elem, Mapping):
        return {key: diff_shape_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(diff_shape_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, Sequence):
        transposed = zip(*batch)
        return [diff_shape_collate(samples) for samples in transposed]

    raise TypeError('not support element type: {}'.format(elem_type))
def default_decollate(
        batch: Union[torch.Tensor, Sequence, Mapping],
        ignore: List[str] = ['prev_state', 'prev_actor_state', 'prev_critic_state']
) -> List[Any]:
    """
    Overview:
        Drag out batch_size collated data's batch size to decollate it, which is the reverse operation of \
        ``default_collate``.
    Arguments:
        - batch (:obj:`Union[torch.Tensor, Sequence, Mapping]`): The collated data batch. It can be a tensor, \
            sequence, or mapping. A sequence that contains only strings is treated as B per-sample strings; any \
            other sequence is treated as a collection of fields that are each decollated and then regrouped per \
            sample.
        - ignore(:obj:`List[str]`): A list of names to be ignored. Only applicable if the input ``batch`` is a \
            dictionary. If a key is in this list, its value is not decollated but indexed per sample as it is. \
            Defaults to ['prev_state', 'prev_actor_state', 'prev_critic_state']. The list is only read.
    Returns:
        - ret (:obj:`List[Any]`): A list with B elements, where B is the batch size.
    Raises:
        - TypeError: If ``batch`` (or a nested value) is of an unsupported type, including a bare string.
    Examples:
        >>> batch = {
        ...     'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
        ...     'b': torch.tensor([[7, 8, 9], [10, 11, 12]]),
        ... }
        >>> default_decollate(batch)
        [{'a': tensor([1, 2, 3]), 'b': tensor([7, 8, 9])}, {'a': tensor([4, 5, 6]), 'b': tensor([10, 11, 12])}]
    """
    if isinstance(batch, torch.Tensor):
        chunks = torch.split(batch, 1, dim=0)
        # Squeeze if the original batch's shape is like (B, dim1, dim2, ...);
        # otherwise, directly return the list.
        if len(chunks[0].shape) > 1:
            return [chunk.squeeze(0) for chunk in chunks]
        return list(chunks)
    elif isinstance(batch, (str, bytes)):
        # A bare string has no batch dimension; reject it instead of recursing on its characters.
        raise TypeError("Not supported batch type: {}".format(type(batch)))
    elif isinstance(batch, Sequence):
        if all(isinstance(item, (str, bytes)) for item in batch):
            # ``default_collate`` keeps string fields as a sequence of B strings.
            return list(batch)
        return list(zip(*[default_decollate(e) for e in batch]))
    elif isinstance(batch, Mapping):
        tmp = {k: v if k in ignore else default_decollate(v) for k, v in batch.items()}
        # Batch size is taken from the first entry.
        B = len(list(tmp.values())[0])
        return [{k: tmp[k][i] for k in tmp.keys()} for i in range(B)]
    elif isinstance(batch, torch.distributions.Distribution):  # For compatibility
        return [None for _ in range(batch.batch_shape[0])]

    raise TypeError("Not supported batch type: {}".format(type(batch)))