|
import random |
|
from typing import Callable, Union, List |
|
|
|
from ding.data.buffer import BufferedData |
|
from ding.utils import fastcopy |
|
|
|
|
|
def padding(policy="random"):
    """
    Overview:
        Fill the nested buffer list to the same size as the largest list.
        With the default policy ``random``, each shorter group is padded with copies of items
        drawn uniformly at random from that group's own original items. With ``none``, it is
        padded with placeholder ``BufferedData(data=None, index=None, meta=None)`` entries.
        Only the ``sample`` action is affected; every other action is forwarded unchanged.
    Arguments:
        - policy (:obj:`str`): Padding policy, supports ``random``, ``none``.
    Returns:
        - middleware (:obj:`Callable`): A middleware with signature ``(action, chain, *args, **kwargs)``.
    Raises:
        - ValueError: If ``policy`` is not one of the supported values.
    """
    # Validate eagerly: an unknown policy used to silently skip padding entirely.
    if policy not in ("random", "none"):
        raise ValueError(f"Unsupported padding policy: {policy!r}, expected 'random' or 'none'")

    def sample(chain: Callable, *args, **kwargs) -> "Union[List[BufferedData], List[List[BufferedData]]]":
        """
        Overview:
            Run the rest of the chain, then pad every group of a grouped result to the length
            of the longest group. The groups are extended in place and the same list is returned.
        Raises:
            - IndexError: If policy is ``random`` and a group that needs padding is empty.
        """
        sampled_data = chain(*args, **kwargs)
        # Empty or flat (ungrouped) results need no padding.
        if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData):
            return sampled_data

        max_len = max(len(group) for group in sampled_data)
        for group in sampled_data:
            missing = max_len - len(group)
            if missing <= 0:
                continue
            if policy == "random":
                if not group:
                    raise IndexError("Cannot pad an empty group with policy 'random': no data to copy from")
                # Snapshot the original items so padded copies are never drawn again;
                # drawing from the growing list would bias later picks towards duplicates.
                candidates = list(group)
                group.extend(fastcopy.copy(random.choice(candidates)) for _ in range(missing))
            else:  # policy == "none"
                group.extend(BufferedData(data=None, index=None, meta=None) for _ in range(missing))

        return sampled_data

    def _padding(action: str, chain: Callable, *args, **kwargs):
        # Only intercept sampling; all other actions pass straight through to the chain.
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _padding
|
|