File size: 1,064 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import Callable, Any, List, Union
from ding.data.buffer import BufferedData
from ding.utils import fastcopy


def clone_object():
    """
    Overview:
        Build a buffer middleware that freezes the objects kept in the memory buffer by copying them.
        Data is copied with ``fastcopy.copy`` before it is pushed, and the sampled result is copied
        before it is returned. Use this middleware when objects must stay unchanged inside the buffer
        while callers modify them after sampling (usually from multiple threads).
    Returns:
        - middleware (:obj:`Callable`): A function ``(action, chain, *args, **kwargs)``. It copies the
          data for ``"push"``, copies the result for ``"sample"``, and forwards every other action to
          ``chain`` unchanged.
    """

    def push(chain: Callable, data: Any, *args, **kwargs) -> BufferedData:
        # Store a private copy so later edits made by the caller never reach the buffer.
        frozen = fastcopy.copy(data)
        return chain(frozen, *args, **kwargs)

    def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
        # Hand out a copy so the consumer cannot mutate what the buffer holds.
        return fastcopy.copy(chain(*args, **kwargs))

    # Actions that need copying; anything not listed here is forwarded as-is.
    intercepted = (("push", push), ("sample", sample))

    def _clone_object(action: str, chain: Callable, *args, **kwargs):
        for name, handler in intercepted:
            if action == name:
                return handler(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _clone_object