|
import math |
|
from functools import partial, lru_cache |
|
from typing import Optional, Dict, Any |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from ding.compatibility import torch_ge_180 |
|
from ding.torch_utils import one_hot |
|
|
|
num_first_one_hot = partial(one_hot, num_first=True) |
|
|
|
|
|
def sqrt_one_hot(v: torch.Tensor, max_val: int) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Sqrt the input value ``v`` and transform it into one-hot. |
|
Arguments: |
|
- v (:obj:`torch.Tensor`): the value to be processed with `sqrt` and `one-hot` |
|
- max_val (:obj:`int`): the input ``v``'s estimated max value, used to calculate one-hot bit number. \ |
|
``v`` would be clamped by (0, max_val). |
|
Returns: |
|
- ret (:obj:`torch.Tensor`): the value processed after `sqrt` and `one-hot` |
|
""" |
|
num = int(math.sqrt(max_val)) + 1 |
|
v = v.float() |
|
v = torch.floor(torch.sqrt(torch.clamp(v, 0, max_val))).long() |
|
return one_hot(v, num) |
|
|
|
|
|
def div_one_hot(v: torch.Tensor, max_val: int, ratio: int) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Divide the input value ``v`` by ``ratio`` and transform it into one-hot. |
|
Arguments: |
|
- v (:obj:`torch.Tensor`): the value to be processed with `divide` and `one-hot` |
|
- max_val (:obj:`int`): the input ``v``'s estimated max value, used to calculate one-hot bit number. \ |
|
``v`` would be clamped by (0, ``max_val``). |
|
- ratio (:obj:`int`): input ``v`` would be divided by ``ratio`` |
|
Returns: |
|
- ret (:obj:`torch.Tensor`): the value processed after `divide` and `one-hot` |
|
""" |
|
num = int(max_val / ratio) + 1 |
|
v = v.float() |
|
v = torch.floor(torch.clamp(v, 0, max_val) / ratio).long() |
|
return one_hot(v, num) |
|
|
|
|
|
def div_func(inputs: torch.Tensor, other: float, unsqueeze_dim: int = 1): |
|
""" |
|
Overview: |
|
Divide ``inputs`` by ``other`` and unsqueeze if needed. |
|
Arguments: |
|
- inputs (:obj:`torch.Tensor`): the value to be unsqueezed and divided |
|
- other (:obj:`float`): input would be divided by ``other`` |
|
- unsqueeze_dim (:obj:`int`): the dim to implement unsqueeze |
|
Returns: |
|
- ret (:obj:`torch.Tensor`): the value processed after `unsqueeze` and `divide` |
|
""" |
|
inputs = inputs.float() |
|
if unsqueeze_dim is not None: |
|
inputs = inputs.unsqueeze(unsqueeze_dim) |
|
return torch.div(inputs, other) |
|
|
|
|
|
def clip_one_hot(v: torch.Tensor, num: int) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Clamp the input ``v`` in (0, num-1) and make one-hot mapping. |
|
Arguments: |
|
- v (:obj:`torch.Tensor`): the value to be processed with `clamp` and `one-hot` |
|
- num (:obj:`int`): number of one-hot bits |
|
Returns: |
|
- ret (:obj:`torch.Tensor`): the value processed after `clamp` and `one-hot` |
|
""" |
|
v = v.clamp(0, num - 1) |
|
return one_hot(v, num) |
|
|
|
|
|
def reorder_one_hot( |
|
v: torch.LongTensor, |
|
dictionary: Dict[int, int], |
|
num: int, |
|
transform: Optional[np.ndarray] = None |
|
) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Reorder each value in input ``v`` according to reorder dict ``dictionary``, then make one-hot mapping |
|
Arguments: |
|
- v (:obj:`torch.LongTensor`): the original value to be processed with `reorder` and `one-hot` |
|
- dictionary (:obj:`Dict[int, int]`): a reorder lookup dict, \ |
|
map original value to new reordered index starting from 0 |
|
- num (:obj:`int`): number of one-hot bits |
|
- transform (:obj:`int`): an array to firstly transform the original action to general action |
|
Returns: |
|
- ret (:obj:`torch.Tensor`): one-hot data indicating reordered index |
|
""" |
|
assert (len(v.shape) == 1) |
|
assert (isinstance(v, torch.Tensor)) |
|
new_v = torch.zeros_like(v) |
|
for idx in range(v.shape[0]): |
|
if transform is None: |
|
val = v[idx].item() |
|
else: |
|
val = transform[v[idx].item()] |
|
new_v[idx] = dictionary[val] |
|
return one_hot(new_v, num) |
|
|
|
|
|
def reorder_one_hot_array( |
|
v: torch.LongTensor, array: np.ndarray, num: int, transform: Optional[np.ndarray] = None |
|
) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Reorder each value in input ``v`` according to reorder dict ``dictionary``, then make one-hot mapping. |
|
The difference between this function and ``reorder_one_hot`` is |
|
whether the type of reorder lookup data structure is `np.ndarray` or `dict`. |
|
Arguments: |
|
- v (:obj:`torch.LongTensor`): the value to be processed with `reorder` and `one-hot` |
|
- array (:obj:`np.ndarray`): a reorder lookup array, map original value to new reordered index starting from 0 |
|
- num (:obj:`int`): number of one-hot bits |
|
- transform (:obj:`np.ndarray`): an array to firstly transform the original action to general action |
|
Returns: |
|
- ret (:obj:`torch.Tensor`): one-hot data indicating reordered index |
|
""" |
|
v = v.numpy() |
|
if transform is None: |
|
val = array[v] |
|
else: |
|
val = array[transform[v]] |
|
return one_hot(torch.LongTensor(val), num) |
|
|
|
|
|
def reorder_boolean_vector( |
|
v: torch.LongTensor, |
|
dictionary: Dict[int, int], |
|
num: int, |
|
transform: Optional[np.ndarray] = None |
|
) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Reorder each value in input ``v`` to new index according to reorder dict ``dictionary``, |
|
then set corresponding position in return tensor to 1. |
|
Arguments: |
|
- v (:obj:`torch.LongTensor`): the value to be processed with `reorder` |
|
- dictionary (:obj:`Dict[int, int]`): a reorder lookup dict, \ |
|
map original value to new reordered index starting from 0 |
|
- num (:obj:`int`): total number of items, should equals to max index + 1 |
|
- transform (:obj:`np.ndarray`): an array to firstly transform the original action to general action |
|
Returns: |
|
- ret (:obj:`torch.Tensor`): boolean data containing only 0 and 1, \ |
|
indicating whether corresponding original value exists in input ``v`` |
|
""" |
|
ret = torch.zeros(num) |
|
for item in v: |
|
try: |
|
if transform is None: |
|
val = item.item() |
|
else: |
|
val = transform[item.item()] |
|
idx = dictionary[val] |
|
except KeyError as e: |
|
|
|
raise KeyError('{}_{}_'.format(num, e)) |
|
ret[idx] = 1 |
|
return ret |
|
|
|
|
|
@lru_cache(maxsize=32) |
|
def get_to_and(num_bits: int) -> np.ndarray: |
|
""" |
|
Overview: |
|
Get an np.ndarray with ``num_bits`` elements, each equals to :math:`2^n` (n decreases from num_bits-1 to 0). |
|
Used by ``batch_binary_encode`` to make bit-wise `and`. |
|
Arguments: |
|
- num_bits (:obj:`int`): length of the generating array |
|
Returns: |
|
- to_and (:obj:`np.ndarray`): an array with ``num_bits`` elements, \ |
|
each equals to :math:`2^n` (n decreases from num_bits-1 to 0) |
|
""" |
|
return 2 ** np.arange(num_bits - 1, -1, -1).reshape([1, num_bits]) |
|
|
|
|
|
def batch_binary_encode(x: torch.Tensor, bit_num: int) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Big endian binary encode ``x`` to float tensor |
|
Arguments: |
|
- x (:obj:`torch.Tensor`): the value to be unsqueezed and divided |
|
- bit_num (:obj:`int`): number of bits, should satisfy :math:`2^{bit num} > max(x)` |
|
Example: |
|
>>> batch_binary_encode(torch.tensor([131,71]), 10) |
|
tensor([[0., 0., 1., 0., 0., 0., 0., 0., 1., 1.], |
|
[0., 0., 0., 1., 0., 0., 0., 1., 1., 1.]]) |
|
Returns: |
|
- ret (:obj:`torch.Tensor`): the binary encoded tensor, containing only `0` and `1` |
|
""" |
|
x = x.numpy() |
|
xshape = list(x.shape) |
|
x = x.reshape([-1, 1]) |
|
to_and = get_to_and(bit_num) |
|
return torch.FloatTensor((x & to_and).astype(bool).astype(float).reshape(xshape + [bit_num])) |
|
|
|
|
|
def compute_denominator(x: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Compute the denominator used in ``get_postion_vector``. \ |
|
Divide 1 at the last step, so you can use it as an multiplier. |
|
Arguments: |
|
- x (:obj:`torch.Tensor`): Input tensor, which is generated from torch.arange(0, d_model). |
|
Returns: |
|
- ret (:obj:`torch.Tensor`): Denominator result tensor. |
|
""" |
|
if torch_ge_180(): |
|
x = torch.div(x, 2, rounding_mode='trunc') * 2 |
|
else: |
|
x = torch.div(x, 2) * 2 |
|
x = torch.div(x, 64.) |
|
x = torch.pow(10000., x) |
|
x = torch.div(1., x) |
|
return x |
|
|
|
|
|
def get_postion_vector(x: list) -> torch.Tensor: |
|
""" |
|
Overview: |
|
Get position embedding used in `Transformer`, even and odd :math:`\alpha` are stored in ``POSITION_ARRAY`` |
|
Arguments: |
|
- x (:obj:`list`): original position index, whose length should be 32 |
|
Returns: |
|
- v (:obj:`torch.Tensor`): position embedding tensor in 64 dims |
|
""" |
|
|
|
POSITION_ARRAY = compute_denominator(torch.arange(0, 64, dtype=torch.float)) |
|
v = torch.zeros(64, dtype=torch.float) |
|
x = torch.FloatTensor(x) |
|
v[0::2] = torch.sin(x * POSITION_ARRAY[0::2]) |
|
v[1::2] = torch.cos(x * POSITION_ARRAY[1::2]) |
|
return v |
|
|
|
|
|
def affine_transform( |
|
data: Any, |
|
action_clip: Optional[bool] = True, |
|
alpha: Optional[float] = None, |
|
beta: Optional[float] = None, |
|
min_val: Optional[float] = None, |
|
max_val: Optional[float] = None |
|
) -> Any: |
|
""" |
|
Overview: |
|
do affine transform for data in range [-1, 1], :math:`\alpha \times data + \beta` |
|
Arguments: |
|
- data (:obj:`Any`): the input data |
|
- action_clip (:obj:`bool`): whether to do action clip operation ([-1, 1]) |
|
- alpha (:obj:`float`): affine transform weight |
|
- beta (:obj:`float`): affine transform bias |
|
- min_val (:obj:`float`): min value, if `min_val` and `max_val` are indicated, scale input data\ |
|
to [min_val, max_val] |
|
- max_val (:obj:`float`): max value |
|
Returns: |
|
- transformed_data (:obj:`Any`): affine transformed data |
|
""" |
|
if action_clip: |
|
data = np.clip(data, -1, 1) |
|
if min_val is not None: |
|
assert max_val is not None |
|
alpha = (max_val - min_val) / 2 |
|
beta = (max_val + min_val) / 2 |
|
assert alpha is not None |
|
beta = beta if beta is not None else 0. |
|
return data * alpha + beta |
|
|
|
|
|
def save_frames_as_gif(frames: list, path: str) -> None: |
|
""" |
|
Overview: |
|
save frames as gif to a specified path. |
|
Arguments: |
|
- frames (:obj:`List`): list of frames |
|
- path (:obj:`str`): the path to save gif |
|
""" |
|
try: |
|
import imageio |
|
except ImportError: |
|
from ditk import logging |
|
import sys |
|
logging.warning("Please install imageio first.") |
|
sys.exit(1) |
|
imageio.mimsave(path, frames, fps=20) |
|
|