File size: 1,648 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import random
from typing import Callable, List
from ding.data.buffer.buffer import BufferedData


def group_sample(size_in_group: int, ordered_in_group: bool = True, max_use_in_group: bool = True) -> Callable:
    """
    Overview:
        Build a buffer middleware that post-processes the result of a grouped ``sample`` call: every group \
        returned by the next handler in the chain is replaced by a sub-sample of at most ``size_in_group`` records.
        Any action other than ``"sample"`` is forwarded to the chain untouched.
    Arguments:
        - size_in_group (:obj:`int`): Maximum number of records kept in each group.
        - ordered_in_group (:obj:`bool`): If true (default), keep a contiguous slice of the group so the original \
            order of records is preserved; otherwise draw records at random without replacement.
        - max_use_in_group (:obj:`bool`): Only used when ``ordered_in_group`` is true. If true (default), the random \
            start index is restricted so that the slice is as long as possible (a full ``size_in_group`` window \
            whenever the group is large enough); otherwise the start index may be anywhere in the group and the \
            slice may be shorter than ``size_in_group``.
    Returns:
        - middleware (:obj:`Callable`): A callable with signature ``(action, chain, *args, **kwargs)``.
    Raises:
        - Exception: When the ``"sample"`` action is invoked without a truthy ``groupby`` keyword argument.

    .. note::
        The list returned by ``chain`` is modified in place (each group entry is reassigned) and then returned.
        Groups holding fewer than ``size_in_group`` records (including empty groups) are returned with whatever \
        records they have instead of raising.
    """

    def sample(chain: Callable, *args, **kwargs) -> "List[List[BufferedData]]":
        # Grouped post-processing only makes sense if the underlying sample call was asked to group.
        if not kwargs.get("groupby"):
            raise Exception("Group sample must be used when the `groupby` parameter is specified.")
        sampled_data = chain(*args, **kwargs)
        for i, grouped_data in enumerate(sampled_data):
            group_len = len(grouped_data)
            if ordered_in_group:
                if max_use_in_group:
                    # Last admissible start index still leaves room for a full window (or is 0 for small groups).
                    end = max(0, group_len - size_in_group) + 1
                else:
                    # Any position may start the window; clamp to 1 so an empty group yields an empty slice
                    # instead of failing on an empty range.
                    end = max(1, group_len)
                start_idx = random.randrange(end)
                sampled_data[i] = grouped_data[start_idx:start_idx + size_in_group]
            else:
                # random.sample rejects k larger than the population, so cap k at the group size.
                sampled_data[i] = random.sample(grouped_data, k=min(size_in_group, group_len))
        return sampled_data

    def _group_sample(action: str, chain: Callable, *args, **kwargs):
        # Only the "sample" action is intercepted; everything else passes straight through.
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _group_sample