# Source: gomoku / DI-engine / ding/policy/common_utils.py
# Commit 079c32c ("init space") by zjowowen; file size 5.3 kB.
from typing import List, Any, Dict, Callable
import torch
import numpy as np
import treetensor.torch as ttorch
from ding.utils.data import default_collate
from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze
def _squeeze_trailing_singleton(x: torch.Tensor) -> torch.Tensor:
    """Turn a ``(B, 1)`` tensor into ``(B, )``; return any other tensor unchanged."""
    if x.dim() == 2 and x.shape[1] == 1:
        return x.squeeze(-1)
    return x


def default_preprocess_learn(
        data: List[Any],
        use_priority_IS_weight: bool = False,
        use_priority: bool = False,
        use_nstep: bool = False,
        ignore_done: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Overview:
        Default data pre-processing in policy's ``_forward_learn`` method: collate the list of samples into one \
        batch, then post-process the ``value``/``adv``, ``done``, ``weight`` and ``reward`` fields.
    Arguments:
        - data (:obj:`List[Any]`): The list of training samples. Each sample is a dict whose ``action`` may be a \
            ``np.ndarray`` or ``torch.Tensor``; every sample must contain ``action``, ``done`` and ``reward``.
        - use_priority_IS_weight (:obj:`bool`): Whether to use priority IS weight correction. If True (which \
            requires ``use_priority``), ``data['weight']`` is set to the ``priority_IS`` field, or to the legacy \
            ``IS`` field when ``priority_IS`` is absent.
        - use_priority (:obj:`bool`): Whether prioritized replay is used. On its own it changes nothing here; it \
            must be True for ``use_priority_IS_weight`` to be allowed.
        - use_nstep (:obj:`bool`): Whether to use nstep TD error. If True, ``reward`` is reshaped from \
            ``(batch_size, nstep)`` (or ``(batch_size, )``) to ``(nstep, batch_size)``.
        - ignore_done (:obj:`bool`): Whether to ignore done. If True, ``done`` is replaced by all zeros.
    Returns:
        - data (:obj:`Dict[str, torch.Tensor]`): The preprocessed dict data whose values can be directly used for \
            the following model forward and loss computation. ``done`` is always float; ``(B, 1)`` shaped \
            ``value``/``adv``/``done`` (and ``reward`` when not n-step) are squeezed to ``(B, )``.
    Raises:
        - AssertionError: If ``use_priority_IS_weight`` is True while ``use_priority`` is False.
        - KeyError: If IS weight correction is enabled but the batch has neither ``priority_IS`` nor ``IS``.
    """
    if use_priority_IS_weight and not use_priority:
        # Explicit raise instead of ``assert`` so the validation is not stripped under ``python -O``.
        raise AssertionError("Use IS Weight correction, but Priority is not used.")

    # int64 actions are treated as discrete and collated with ``cat_1dim=True``; anything else is collated as
    # a continuous action with ``cat_1dim=False``.
    action = data[0]['action']
    is_discrete = isinstance(action, (np.ndarray, torch.Tensor)) and action.dtype in [np.int64, torch.int64]
    data = default_collate(data, cat_1dim=is_discrete)

    for key in ('value', 'adv'):
        if key in data:
            data[key] = _squeeze_trailing_singleton(data[key])

    done = torch.zeros_like(data['done']) if ignore_done else data['done']
    data['done'] = _squeeze_trailing_singleton(done.float())

    if use_priority_IS_weight:  # ``use_priority`` is guaranteed True by the check above
        # ``IS`` is the legacy key name, kept for compatibility.
        data['weight'] = data['priority_IS'] if 'priority_IS' in data else data['IS']
    else:
        data['weight'] = data.get('weight', None)

    reward = data['reward']
    if use_nstep:
        # reward reshaping for n-step: (batch_size, nstep) -> (nstep, batch_size)
        if reward.dim() == 1:
            reward = reward.unsqueeze(1)
        data['reward'] = reward.permute(1, 0).contiguous()
    else:
        data['reward'] = _squeeze_trailing_singleton(reward)
    return data
def single_env_forward_wrapper(forward_fn: Callable) -> Callable:
    """
    Overview:
        Adapt a batched policy forward function to gym-style interaction with a single environment: the wrapped \
        function takes one raw observation and returns one action as a numpy array.
    Arguments:
        - forward_fn (:obj:`Callable`): The original forward function of policy. It is called with a dict that \
            maps env id ``0`` to the batched observation, and its result is indexed as ``result[0]['action']``.
    Returns:
        - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.
    Examples:
        >>> env = gym.make('CartPole-v0')
        >>> policy = DQNPolicy(...)
        >>> forward_fn = single_env_forward_wrapper(policy.eval_mode.forward)
        >>> obs = env.reset()
        >>> action = forward_fn(obs)
        >>> next_obs, rew, done, info = env.step(action)
    """

    def _single_obs_forward(raw_obs):
        # Convert the observation to a tensor and add a batch dimension via ding's ``unsqueeze``.
        batched_obs = unsqueeze(to_tensor(raw_obs))
        # Key ``0`` identifies the only environment being driven.
        policy_output = forward_fn({0: batched_obs})
        batched_action = policy_output[0]['action']
        # Remove the batch dimension again and hand a numpy array back to the environment.
        return to_ndarray(squeeze(batched_action))

    return _single_obs_forward
def single_env_forward_wrapper_ttorch(forward_fn: Callable, cuda: bool = True) -> Callable:
    """
    Overview:
        Wrap policy to support gym-style interaction between policy and single environment for treetensor (ttorch) data.
    Arguments:
        - forward_fn (:obj:`Callable`): The original forward function of policy. It receives the batched ttorch \
            observation and its result must expose an ``action`` attribute.
        - cuda (:obj:`bool`): Whether to use cuda in policy. If True and CUDA is available, the input data is \
            moved to cuda before calling ``forward_fn``; the returned action is always moved back to cpu.
    Returns:
        - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy, mapping one observation \
            to one action as numpy data.
    Examples:
        >>> env = gym.make('CartPole-v0')
        >>> policy = PPOFPolicy(...)
        >>> forward_fn = single_env_forward_wrapper_ttorch(policy.eval)
        >>> obs = env.reset()
        >>> action = forward_fn(obs)
        >>> next_obs, rew, done, info = env.step(action)
    """
    # Decide the device once at wrap time instead of querying CUDA availability on every environment step.
    use_cuda = cuda and torch.cuda.is_available()

    def _forward(obs):
        # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
        obs = ttorch.as_tensor(obs).unsqueeze(0)
        if use_cuda:
            obs = obs.cuda()
        action = forward_fn(obs).action
        # squeeze means delete batch dim, i.e. (1, A) -> (A, )
        action = action.squeeze(0).cpu().numpy()
        return action

    return _forward