from typing import List, Any, Dict, Callable
import torch
import numpy as np
import treetensor.torch as ttorch
from ding.utils.data import default_collate
from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze


def default_preprocess_learn(
        data: List[Any],
        use_priority_IS_weight: bool = False,
        use_priority: bool = False,
        use_nstep: bool = False,
        ignore_done: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Overview:
        Default data pre-processing in policy's ``_forward_learn`` method, including stacking batch data and \
        processing the done flag, n-step reward and priority IS weight.
    Arguments:
        - data (:obj:`List[Any]`): The list of training batch samples, where each sample is a dict of PyTorch Tensors.
        - use_priority_IS_weight (:obj:`bool`): Whether to use priority IS weight correction; if True, this function \
            will set the weight of each sample to its priority IS weight.
        - use_priority (:obj:`bool`): Whether to use priority; if True, this function will set the priority IS weight.
        - use_nstep (:obj:`bool`): Whether to use n-step TD error; if True, this function will reshape the reward \
            into shape ``(nstep, batch_size)``.
        - ignore_done (:obj:`bool`): Whether to ignore done; if True, this function will set all done flags to 0.
    Returns:
        - data (:obj:`Dict[str, torch.Tensor]`): The preprocessed dict data whose values can be directly used for \
            the following model forward and loss computation.
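    Examples:
        >>> # A minimal sketch with an assumed batch layout: a discrete action stored as an int64 tensor of
        >>> # shape (1, ), as commonly produced by a replay buffer; real samples may contain more keys.
        >>> data = [
        >>>     {
        >>>         'obs': torch.randn(4),
        >>>         'action': torch.tensor([1]),
        >>>         'reward': torch.tensor([1.]),
        >>>         'done': False,
        >>>     } for _ in range(8)
        >>> ]
        >>> data = default_preprocess_learn(data)
        >>> assert data['obs'].shape == (8, 4)
        >>> assert data['action'].shape == (8, )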
    """
    # data preprocess
    elem = data[0]
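    # int64 actions indicate a discrete action space, whose (1, )-shaped action tensors are concatenated
    # into shape (B, ); continuous actions are stacked with an extra batch dim instead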
    if isinstance(elem['action'], (np.ndarray, torch.Tensor)) and elem['action'].dtype in [np.int64, torch.int64]:
        data = default_collate(data, cat_1dim=True)  # for discrete action
    else:
        data = default_collate(data, cat_1dim=False)  # for continuous action
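    # squeeze the trailing singleton dim of value/adv (if present) to align shapes: (B, 1) -> (B, )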
    if 'value' in data and data['value'].dim() == 2 and data['value'].shape[1] == 1:
        data['value'] = data['value'].squeeze(-1)
    if 'adv' in data and data['adv'].dim() == 2 and data['adv'].shape[1] == 1:
        data['adv'] = data['adv'].squeeze(-1)

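    # cast the done flag to float for loss computation; if ``ignore_done`` is True, every transition is
    # treated as non-terminal (e.g. for tasks where done only marks a time limit)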
    if ignore_done:
        data['done'] = torch.zeros_like(data['done']).float()
    else:
        data['done'] = data['done'].float()

    if data['done'].dim() == 2 and data['done'].shape[1] == 1:
        data['done'] = data['done'].squeeze(-1)

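    # when prioritized sampling is combined with IS weight correction, use the priority IS weight as the
    # per-sample loss weight to correct the non-uniform sampling bias; otherwise keep the original weight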
    if use_priority_IS_weight:
        assert use_priority, "Use IS Weight correction, but Priority is not used."
    if use_priority and use_priority_IS_weight:
        if 'priority_IS' in data:
            data['weight'] = data['priority_IS']
        else:  # for compatibility with the old data format
            data['weight'] = data['IS']
    else:
        data['weight'] = data.get('weight', None)
    if use_nstep:
        # reward reshaping for n-step
        reward = data['reward']
        if reward.dim() == 1:
            reward = reward.unsqueeze(1)
        # reward: (batch_size, nstep) -> (nstep, batch_size)
        data['reward'] = reward.permute(1, 0).contiguous()
    else:
        if data['reward'].dim() == 2 and data['reward'].shape[1] == 1:
            data['reward'] = data['reward'].squeeze(-1)

    return data


def single_env_forward_wrapper(forward_fn: Callable) -> Callable:
    """
    Overview:
        Wrap a policy to support gym-style interaction between the policy and a single environment.
    Arguments:
        - forward_fn (:obj:`Callable`): The original forward function of policy.
    Returns:
        - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.
    Examples:
        >>> env = gym.make('CartPole-v0')
        >>> policy = DQNPolicy(...)
        >>> forward_fn = single_env_forward_wrapper(policy.eval_mode.forward)
        >>> obs = env.reset()
        >>> action = forward_fn(obs)
        >>> next_obs, rew, done, info = env.step(action)
    """

    def _forward(obs):
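        # pack obs into a dict keyed by env id (a single env uses id 0) and add the batch dim: (O, ) -> (1, O)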
        obs = {0: unsqueeze(to_tensor(obs))}
        action = forward_fn(obs)[0]['action']
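        # remove the batch dim and convert the action back to np.ndarray for env.step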
        action = to_ndarray(squeeze(action))
        return action

    return _forward


def single_env_forward_wrapper_ttorch(forward_fn: Callable, cuda: bool = True) -> Callable:
    """
    Overview:
        Wrap a policy to support gym-style interaction between the policy and a single environment, operating on \
        treetensor (ttorch) data.
    Arguments:
        - forward_fn (:obj:`Callable`): The original forward function of policy.
        - cuda (:obj:`bool`): Whether to use cuda in policy; if True, this function will move the input data to cuda.
    Returns:
        - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.

    Examples:
        >>> env = gym.make('CartPole-v0')
        >>> policy = PPOFPolicy(...)
        >>> forward_fn = single_env_forward_wrapper_ttorch(policy.eval)
        >>> obs = env.reset()
        >>> action = forward_fn(obs)
        >>> next_obs, rew, done, info = env.step(action)
    """

    def _forward(obs):
        # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
        obs = ttorch.as_tensor(obs).unsqueeze(0)
        if cuda and torch.cuda.is_available():
            obs = obs.cuda()
        action = forward_fn(obs).action
        # squeeze means delete batch dim, i.e. (1, A) -> (A, )
        action = action.squeeze(0).cpu().numpy()
        return action

    return _forward