File size: 1,648 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import random
from typing import Callable, List
from ding.data.buffer.buffer import BufferedData
def group_sample(size_in_group: int, ordered_in_group: bool = True, max_use_in_group: bool = True) -> Callable:
    """
    Overview:
        The middleware is designed to process the data in each group after sampling from the buffer.
        Only the ``sample`` action is intercepted; every other action is forwarded to ``chain`` unchanged.
    Arguments:
        - size_in_group (:obj:`int`): Maximum number of records kept in each group. A group holding fewer \
            records than this is returned with all of its records (it is never padded).
        - ordered_in_group (:obj:`bool`): Whether to keep the original order of records, default is true. \
            When true, a contiguous window of the group is taken; otherwise records are drawn at random \
            without replacement.
        - max_use_in_group (:obj:`bool`): Whether to use as much data in each group as possible, default is true. \
            Only affects the ordered mode: when true, the window start is restricted so the window is as full \
            as possible; when false, the start may be any index, so windows near the tail can be shorter.
    Returns:
        - middleware (:obj:`Callable`): A function with signature ``(action, chain, *args, **kwargs)``.
    Raises:
        - ValueError: When the ``sample`` action is invoked without a truthy ``groupby`` keyword argument.
    """

    def _pick_from_group(grouped_data: List[BufferedData]) -> List[BufferedData]:
        """Select at most ``size_in_group`` records from a single group."""
        if ordered_in_group:
            if max_use_in_group:
                # Number of start positions at which a full window still fits
                # (just one, index 0, when the group is shorter than the window).
                end = max(0, len(grouped_data) - size_in_group) + 1
            else:
                # Any index may start the window; clamp to 1 so an empty group
                # yields an empty slice instead of randrange(0) raising.
                end = max(1, len(grouped_data))
            start_idx = random.randrange(end)
            return grouped_data[start_idx:start_idx + size_in_group]
        # random.sample raises ValueError when k exceeds the population size;
        # clamp so short groups are returned whole, consistent with the ordered mode.
        return random.sample(grouped_data, k=min(size_in_group, len(grouped_data)))

    def sample(chain: Callable, *args, **kwargs) -> List[List[BufferedData]]:
        if not kwargs.get("groupby"):
            raise ValueError("The `groupby` parameter must be specified when using group sample.")
        sampled_data = chain(*args, **kwargs)
        # Build a new list rather than mutating whatever container `chain` returned.
        return [_pick_from_group(grouped_data) for grouped_data in sampled_data]

    def _group_sample(action: str, chain: Callable, *args, **kwargs):
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _group_sample
|