|
from typing import Dict, Any, Callable |
|
from collections import namedtuple |
|
from easydict import EasyDict |
|
import gym |
|
import torch |
|
|
|
from ding.torch_utils import to_device |
|
|
|
|
|
class PolicyFactory:
    """
    Overview:
        Policy factory class, used to generate different policies for general purposes, such as a random action \
        policy, which is used for initial sample collection (for better exploration) when \
        ``random_collect_size`` > 0.
    Interfaces:
        ``get_random_policy``
    """

    @staticmethod
    def get_random_policy(
            policy: 'Policy.collect_mode',
            action_space: 'gym.spaces.Space' = None,
            forward_fn: Callable = None,
    ) -> 'Policy.collect_mode':
        """
        Overview:
            According to the given action space, define the forward function of the random policy, then pack it with \
            other interfaces of the given policy, and return the final collect mode interfaces of policy.
        Arguments:
            - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy. Its \
                ``process_transition``, ``get_train_sample`` and ``get_attribute`` are reused unchanged.
            - action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style. A ``list`` \
                of spaces is treated as a multi-agent action space (one space per agent).
            - forward_fn (:obj:`Callable`): If the action space is too complex, you can define your own forward \
                function and pass it here. When given, it takes precedence over ``action_space``.
        Returns:
            - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy, a \
                namedtuple with fields ``forward``, ``process_transition``, ``get_train_sample``, ``reset`` and \
                ``get_attribute``.
        Raises:
            - AssertionError: If both ``action_space`` and ``forward_fn`` are ``None``.
        """
        assert not (action_space is None and forward_fn is None), \
            'At least one of `action_space` and `forward_fn` must be provided.'

        random_collect_function = namedtuple(
            'random_collect_function', [
                'forward',
                'process_transition',
                'get_train_sample',
                'reset',
                'get_attribute',
            ]
        )

        def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
            """Sample one random action per env id in ``data``; extra arguments are ignored."""
            actions = {}
            for env_id in data:
                if not isinstance(action_space, list):
                    # Single gym space: wrap the sample in tensors according to the space type.
                    if isinstance(action_space, gym.spaces.Discrete):
                        action = torch.LongTensor([action_space.sample()])
                    elif isinstance(action_space, gym.spaces.MultiDiscrete):
                        action = [torch.LongTensor([v]) for v in action_space.sample()]
                    else:
                        action = torch.as_tensor(action_space.sample())
                    actions[env_id] = {'action': action}
                elif 'global_state' in data[env_id]:
                    # Multi-agent observation carrying an ``action_mask``: sample uniformly among the allowed
                    # actions by giving masked-out (== 0) actions a hugely negative logit.
                    action_mask = torch.as_tensor(data[env_id]['action_mask'])
                    # Build float logits explicitly so integer/bool masks work too (an integer mask would
                    # otherwise yield an integer logit tensor, and a bool mask would turn -1e8 into True).
                    logit = torch.ones_like(action_mask, dtype=torch.float32)
                    logit[action_mask == 0] = -1e8
                    dist = torch.distributions.categorical.Categorical(logits=logit)
                    actions[env_id] = {'action': dist.sample(), 'logit': logit}
                else:
                    # Multi-agent without mask: one independent sample per agent space.
                    # NOTE(review): the logit shape assumes every agent space is Discrete with the same ``n``
                    # as the first one — confirm for heterogeneous agents.
                    actions[env_id] = {
                        'action': torch.as_tensor([action_space_agent.sample() for action_space_agent in action_space]),
                        'logit': torch.ones([len(action_space), action_space[0].n])
                    }
            return actions

        def reset(*args, **kwargs) -> None:
            # The generated random forward keeps no state between calls, so there is nothing to reset.
            pass

        # A user-supplied forward function wins when both arguments are given; otherwise fall back to the
        # space-driven random forward. This guarantees a valid namedtuple is always returned.
        chosen_forward = forward_fn if forward_fn is not None else forward
        return random_collect_function(
            chosen_forward, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
        )
|
|
|
|
|
def get_random_policy(
        cfg: EasyDict,
        policy: 'Policy.collect_mode',
        env: 'BaseEnvManager'
) -> 'Policy.collect_mode':
    """
    Overview:
        Entry point for obtaining the random policy used during initial collection. When the config flag \
        ``cfg.policy.transition_with_policy_data`` is truthy (the policy needs its own data items in each \
        transition), the given policy is returned unchanged; otherwise a generic random policy is built by \
        ``PolicyFactory`` from the env manager's action space.
    Arguments:
        - cfg (:obj:`EasyDict`): The EasyDict-type dict configuration; only ``cfg.policy`` is read here.
        - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
        - env (:obj:`BaseEnvManager`): The env manager instance whose ``action_space`` drives random \
            action generation.
    Returns:
        - random_policy (:obj:`Policy.collect_mode`): Either ``policy`` itself or the collect mode interfaces \
            of a generic random policy.
    """
    needs_policy_data = cfg.policy.get('transition_with_policy_data', False)
    if needs_policy_data:
        # The transition must carry policy-specific data, which a generic random policy cannot produce.
        return policy
    return PolicyFactory.get_random_policy(policy, action_space=env.action_space)
|
|