from typing import Dict, Any, Callable
from collections import namedtuple
from easydict import EasyDict
import gym
import torch


class PolicyFactory:
    """
    Overview:
        Policy factory class, used to generate different policies for general purposes, such as a random action \
        policy, which is used to collect initial samples for better exploration when ``random_collect_size`` > 0.
    Interfaces:
        ``get_random_policy``
    """

    @staticmethod
    def get_random_policy(
            policy: 'Policy.collect_mode',  # noqa
            action_space: 'gym.spaces.Space' = None,  # noqa
            forward_fn: Callable = None,
    ) -> 'Policy.collect_mode':  # noqa
        """
        Overview:
            According to the given action space, define the forward function of the random policy, then pack it \
            with the other interfaces of the given policy, and return the final collect mode interfaces of the policy.
        Arguments:
            - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
            - action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style.
            - forward_fn (:obj:`Callable`): If the action space is too complex, you can define your own forward \
                function and pass it in; note that you should set ``action_space`` to ``None`` in this case.
        Returns:
            - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
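        Examples (a hedged sketch; ``policy`` below stands for any existing collect-mode policy instance):
            >>> space = gym.spaces.Discrete(4)
            >>> random_policy = PolicyFactory.get_random_policy(policy, action_space=space)
            >>> output = random_policy.forward({0: None})  # data keys are env ids; values are unused for simple spaces
            >>> assert output[0]['action'].shape == (1, )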
        """
        assert not (action_space is None and forward_fn is None), \
            "Either action_space or forward_fn must be provided."
        random_collect_function = namedtuple(
            'random_collect_function', [
                'forward',
                'process_transition',
                'get_train_sample',
                'reset',
                'get_attribute',
            ]
        )

        def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
            actions = {}
            for env_id in data:
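                # Single (non-list) action space: sample an action directly from the gym space.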
                if not isinstance(action_space, list):
                    if isinstance(action_space, gym.spaces.Discrete):
                        action = torch.LongTensor([action_space.sample()])
                    elif isinstance(action_space, gym.spaces.MultiDiscrete):
                        action = [torch.LongTensor([v]) for v in action_space.sample()]
                    else:
                        action = torch.as_tensor(action_space.sample())
                    actions[env_id] = {'action': action}
                elif 'global_state' in data[env_id].keys():
                    # Multi-agent case with action masks (e.g. SMAC): give unavailable actions a
                    # large negative logit so the categorical sample never selects them.
                    logit = torch.ones_like(data[env_id]['action_mask'])
                    logit[data[env_id]['action_mask'] == 0.0] = -1e8
                    dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
                    actions[env_id] = {'action': dist.sample(), 'logit': torch.as_tensor(logit)}
                else:
                    # A list of per-agent action spaces (e.g. GFootball): sample each agent's
                    # space independently and attach a uniform (all-ones) logit.
                    actions[env_id] = {
                        'action': torch.as_tensor([action_space_agent.sample() for action_space_agent in action_space]),
                        'logit': torch.ones([len(action_space), action_space[0].n])
                    }
            return actions

        def reset(*args, **kwargs) -> None:
            pass

        # Prefer a user-provided forward_fn; otherwise fall back to the general random
        # forward defined above. The original if/elif implicitly returned None when both
        # arguments were given, which violated the documented contract.
        if forward_fn is None:
            forward_fn = forward
        return random_collect_function(
            forward_fn, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
        )


def get_random_policy(
        cfg: EasyDict,
        policy: 'Policy.collect_mode',  # noqa
        env: 'BaseEnvManager'  # noqa
) -> 'Policy.collect_mode':  # noqa
    """
    Overview:
        The entry function for getting the corresponding random policy. If the policy needs special data items in \
        a transition, return the policy itself; otherwise, use ``PolicyFactory`` to return a general random policy.
    Arguments:
        - cfg (:obj:`EasyDict`): The EasyDict-type dict configuration.
        - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
        - env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random \
            action generation.
    Returns:
        - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
    """
    if cfg.policy.get('transition_with_policy_data', False):
        return policy
    else:
        action_space = env.action_space
        return PolicyFactory.get_random_policy(policy, action_space=action_space)
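

if __name__ == '__main__':
    # A minimal, self-contained usage sketch (illustrative only, not part of the library API):
    # ``_CollectModeStub`` and ``_EnvStub`` below are hypothetical stand-ins for DI-engine's
    # ``policy.collect_mode`` and ``BaseEnvManager`` instances, assumed here only so that the
    # entry function can be exercised end-to-end with a simple discrete action space.
    def _noop(*args, **kwargs) -> None:
        pass

    _CollectModeStub = namedtuple('_CollectModeStub', ['process_transition', 'get_train_sample', 'get_attribute'])
    _EnvStub = namedtuple('_EnvStub', ['action_space'])
    stub_policy = _CollectModeStub(_noop, _noop, _noop)
    stub_env = _EnvStub(action_space=gym.spaces.Discrete(4))
    cfg = EasyDict({'policy': {}})  # no ``transition_with_policy_data`` -> fall back to PolicyFactory
    random_policy = get_random_policy(cfg, stub_policy, stub_env)
    # Draw random actions for two parallel environments (keys are env ids).
    output = random_policy.forward({0: None, 1: None})
    assert set(output.keys()) == {0, 1} and output[0]['action'].shape == (1, )

    # The custom ``forward_fn`` path (``action_space=None``), using a hypothetical
    # constant-action function purely for demonstration.
    def _constant_forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
        return {env_id: {'action': torch.LongTensor([0])} for env_id in data}

    custom_policy = PolicyFactory.get_random_policy(stub_policy, forward_fn=_constant_forward)
    assert custom_policy.forward({0: None})[0]['action'].item() == 0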