|
import random |
|
from typing import Callable, List |
|
from ding.data.buffer.buffer import BufferedData |
|
|
|
|
|
def group_sample(size_in_group: int, ordered_in_group: bool = True, max_use_in_group: bool = True) -> Callable: |
|
""" |
|
Overview: |
|
The middleware is designed to process the data in each group after sampling from the buffer. |
|
Arguments: |
|
- size_in_group (:obj:`int`): Sample size in each group. |
|
- ordered_in_group (:obj:`bool`): Whether to keep the original order of records, default is true. |
|
- max_use_in_group (:obj:`bool`): Whether to use as much data in each group as possible, default is true. |
|
""" |
|
|
|
def sample(chain: Callable, *args, **kwargs) -> List[List[BufferedData]]: |
|
if not kwargs.get("groupby"): |
|
raise Exception("Group sample must be used when the `groupby` parameter is specified.") |
|
sampled_data = chain(*args, **kwargs) |
|
for i, grouped_data in enumerate(sampled_data): |
|
if ordered_in_group: |
|
if max_use_in_group: |
|
end = max(0, len(grouped_data) - size_in_group) + 1 |
|
else: |
|
end = len(grouped_data) |
|
start_idx = random.choice(range(end)) |
|
sampled_data[i] = grouped_data[start_idx:start_idx + size_in_group] |
|
else: |
|
sampled_data[i] = random.sample(grouped_data, k=size_in_group) |
|
return sampled_data |
|
|
|
def _group_sample(action: str, chain: Callable, *args, **kwargs): |
|
if action == "sample": |
|
return sample(chain, *args, **kwargs) |
|
return chain(*args, **kwargs) |
|
|
|
return _group_sample |
|
|