File size: 5,953 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import importlib
from ditk import logging
from collections import OrderedDict
from functools import wraps
import ding
'''
Overview:
    `hpc_wrapper` is the wrapper for functions which are supported by hpc. If a function is wrapped by it, we will
    search for its hpc type and return the function implemented by hpc.
    We will use the following code as a sample to introduce `hpc_wrapper`:
    ```
    @hpc_wrapper(shape_fn=shape_fn_dntd, namedtuple_data=True, include_args=[0,1,2,3],
                 include_kwargs=['data', 'gamma', 'v_min', 'v_max'], is_cls_method=False)
    def dist_nstep_td_error(
            data: namedtuple,
            gamma: float,
            v_min: float,
            v_max: float,
            n_atom: int,
            nstep: int = 1,
    ) -> torch.Tensor:
    ...
    ```
Parameters:
    - shape_fn (:obj:`function`): a function which return the shape needed by hpc function. In fact, it returns
        all args that the hpc function needs.
    - namedtuple_data (:obj:`bool`): If True, when hpc function is called, it will be called as hpc_function(*namedtuple).
        If False, namedtuple data will remain its `namedtuple` type.
    - include_args (:obj:`list`): a list of index of the args need to be set in hpc function. As shown in the sample,
        include_args=[0,1,2,3], which means `data`, `gamma`, `v_min` and `v_max` will be set in hpc function.
    - include_kwargs (:obj:`list`): a list of key of the kwargs need to be set in hpc function. As shown in the sample,
        include_kwargs=['data', 'gamma', 'v_min', 'v_max'], which means `data`, `gamma`, `v_min` and `v_max` will be
        set in hpc function.
    - is_cls_method (:obj:`bool`): If True, it means the function we wrap is a method of a class. `self` will be put
        into args. We will get rid of `self` in args. Besides, we will use its classname as its fn_name.
        If False, it means the function is a simple method.
Q&A:
    - Q: Do `include_args` and `include_kwargs` need to be set at the same time?
    - A: Yes. `include_args` and `include_kwargs` can deal with all types of input, such as (data, gamma, v_min=v_min,
        v_max=v_max) and (data, gamma, v_min, v_max).
    - Q: What is `hpc_fns`?
    - A: Here we show a normal `hpc_fns`:
         ```
         hpc_fns = {
             'fn_name1': {
                 'runtime_name1': hpc_fn1,
                 'runtime_name2': hpc_fn2,
                 ...
             },
             ...
         }
         ```
         Besides, `per_fn_limit` means the max length of `hpc_fns[fn_name]`. When new function comes, the oldest
         function will be popped from `hpc_fns[fn_name]`.
'''

# Registry of instantiated hpc functions: {fn_name: OrderedDict{runtime_name: hpc_fn}}.
# The inner OrderedDict acts as a FIFO cache keyed by runtime (shape-specific) name.
hpc_fns = {}
# Max number of cached runtime variants per fn_name; when exceeded the oldest entry is evicted.
per_fn_limit = 3


def register_runtime_fn(fn_name, runtime_name, shape):
    """
    Overview:
        Import the hpc implementation class registered for ``fn_name``, instantiate it with ``shape``
        on cuda, and cache the instance in the module-level ``hpc_fns`` registry under ``runtime_name``.
        Each per-function cache is bounded by ``per_fn_limit``; the oldest entry is evicted first
        (FIFO via ``OrderedDict.popitem(last=False)``).
    Arguments:
        - fn_name (:obj:`str`): name of the wrapped function (or its class name for class methods).
        - runtime_name (:obj:`str`): unique key combining ``fn_name`` and its runtime shape.
        - shape (:obj:`list`): positional constructor arguments for the hpc class.
    Returns:
        - hpc_fn (:obj:`object`): the instantiated, cuda-resident hpc function object.
    Raises:
        - KeyError: if ``fn_name`` has no registered hpc implementation.
    """
    fn_name_mapping = {
        'gae': ['hpc_rll.rl_utils.gae', 'GAE'],
        'dist_nstep_td_error': ['hpc_rll.rl_utils.td', 'DistNStepTD'],
        'LSTM': ['hpc_rll.torch_utils.network.rnn', 'LSTM'],
        'ppo_error': ['hpc_rll.rl_utils.ppo', 'PPO'],
        'q_nstep_td_error': ['hpc_rll.rl_utils.td', 'QNStepTD'],
        'q_nstep_td_error_with_rescale': ['hpc_rll.rl_utils.td', 'QNStepTDRescale'],
        'ScatterConnection': ['hpc_rll.torch_utils.network.scatter_connection', 'ScatterConnection'],
        'td_lambda_error': ['hpc_rll.rl_utils.td', 'TDLambda'],
        'upgo_loss': ['hpc_rll.rl_utils.upgo', 'UPGO'],
        'vtrace_error_discrete_action': ['hpc_rll.rl_utils.vtrace', 'VTrace'],
    }
    # Fail with an informative message instead of a bare KeyError from dict indexing.
    if fn_name not in fn_name_mapping:
        raise KeyError(
            "no hpc implementation registered for '{}', available: {}".format(
                fn_name, list(fn_name_mapping.keys())
            )
        )
    module_path, cls_name = fn_name_mapping[fn_name]
    cls = getattr(importlib.import_module(module_path), cls_name)
    hpc_fn = cls(*shape).cuda()
    hpc_fns.setdefault(fn_name, OrderedDict())[runtime_name] = hpc_fn
    # Evict oldest runtime variants once the per-function cache exceeds its limit.
    while len(hpc_fns[fn_name]) > per_fn_limit:
        hpc_fns[fn_name].popitem(last=False)
    return hpc_fn


def hpc_wrapper(shape_fn=None, namedtuple_data=False, include_args=None, include_kwargs=None, is_cls_method=False):
    """
    Overview:
        Decorator factory that transparently swaps a python implementation for its hpc (cuda)
        counterpart when ``ding.enable_hpc_rl`` is True; otherwise the wrapped function is
        called unchanged.
    Arguments:
        - shape_fn (:obj:`function`): callable ``(args, kwargs) -> shape`` producing the
            constructor arguments for the hpc class.
        - namedtuple_data (:obj:`bool`): if True, ``args[0]`` is a namedtuple whose fields are
            unpacked positionally into the hpc call.
        - include_args (:obj:`list` or :obj:`None`): indices of positional args forwarded to the
            hpc function; ``None`` (default) means none.
        - include_kwargs (:obj:`list` or :obj:`None`): keys of keyword args forwarded to the hpc
            function; ``None`` (default) means none.
        - is_cls_method (:obj:`bool`): if True, ``args[0]`` is ``self``; it is stripped from the
            forwarded args and the class name is used as the hpc fn name.
    Returns:
        - decorate (:obj:`function`): the actual decorator.
    """
    # NOTE: `None` sentinels replace the original mutable `[]` defaults, which are shared
    # across all calls of the factory (classic mutable-default pitfall). Behavior is identical.
    include_args = [] if include_args is None else include_args
    include_kwargs = [] if include_kwargs is None else include_kwargs

    def decorate(fn):

        @wraps(fn)
        def wrapper(*args, **kwargs):
            if not ding.enable_hpc_rl:
                # hpc disabled: fall back to the wrapped python implementation.
                return fn(*args, **kwargs)
            shape = shape_fn(args, kwargs)
            # Class methods are registered under their class name rather than the method name.
            fn_name = args[0].__class__.__name__ if is_cls_method else fn.__name__
            runtime_name = '_'.join([fn_name] + [str(s) for s in shape])
            if fn_name not in hpc_fns or runtime_name not in hpc_fns[fn_name]:
                hpc_fn = register_runtime_fn(fn_name, runtime_name, shape)
            else:
                hpc_fn = hpc_fns[fn_name][runtime_name]
            if is_cls_method:
                args = args[1:]  # drop `self`; the cached hpc object replaces the instance
            clean_args = [args[i] for i in include_args if i < len(args)]
            nouse_args = list(set(range(len(args))).difference(include_args))
            clean_kwargs = {}
            for k, v in kwargs.items():
                if k in include_kwargs:
                    # `lambda` is a python keyword, so callers pass `lambda_`;
                    # the hpc side expects the key `lambda`.
                    if k == 'lambda_':
                        k = 'lambda'
                    clean_kwargs[k] = v
            nouse_kwargs = list(set(kwargs.keys()).difference(include_kwargs))
            if nouse_args or nouse_kwargs:
                # `logging.warn` is a deprecated alias of `logging.warning`.
                logging.warning(
                    'in {}, index {} of args are dropped, and keys {} of kwargs are dropped.'.format(
                        runtime_name, nouse_args, nouse_kwargs
                    )
                )
            if namedtuple_data:
                data = args[0]  # args[0] is a namedtuple
                return hpc_fn(*data, *clean_args[1:], **clean_kwargs)
            return hpc_fn(*clean_args, **clean_kwargs)

        return wrapper

    return decorate