from typing import Callable, List, Union
import torch
from torch import Tensor
from ding.torch_utils import fold_batch, unfold_batch
from ding.rl_utils import generalized_lambda_returns
from ding.torch_utils.network.dreamer import static_scan
def q_evaluation(obss: Tensor, actions: Tensor,
                 q_critic_fn: Callable[[Tensor, Tensor], Tensor]) -> Union[Tensor, List[Tensor]]:
"""
Overview:
Evaluate (observation, action) pairs along the trajectory
Arguments:
- obss (:obj:`torch.Tensor`): the observations along the trajectory
- actions (:obj:`torch.Size`): the actions along the trajectory
- q_critic_fn (:obj:`Callable`): the unified API :math:`Q(S_t, A_t)`
Returns:
- q_value (:obj:`torch.Tensor`): the action-value function evaluated along the trajectory
Shapes:
:math:`N`: time step
:math:`B`: batch size
:math:`O`: observation dimension
:math:`A`: action dimension
- obss: [N, B, O]
- actions: [N, B, A]
- q_value: [N, B]
"""
obss, dim = fold_batch(obss, 1)
actions, _ = fold_batch(actions, 1)
q_values = q_critic_fn(obss, actions)
# twin critic
if isinstance(q_values, list):
return [unfold_batch(q_values[0], dim), unfold_batch(q_values[1], dim)]
return unfold_batch(q_values, dim)
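

# Illustrative usage sketch (not part of the original module): evaluating a toy
# Q-function over a trajectory of shape [N, B, O] / [N, B, A]. The critic below
# is a hypothetical stand-in; any callable with the Q(S_t, A_t) signature works.
def _example_q_evaluation():
    obss = torch.randn(10, 4, 3)     # [N, B, O]
    actions = torch.randn(10, 4, 2)  # [N, B, A]
    toy_critic = lambda o, a: o.sum(-1) + a.sum(-1)    # maps folded [N * B, ...] inputs to [N * B] Q-values
    q_value = q_evaluation(obss, actions, toy_critic)  # unfolded back to [N, B]
    return q_value
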
def imagine(cfg, world_model, start, actor, horizon, repeats=None):
    """
    Overview:
        Unroll the world model's latent dynamics for ``horizon`` steps from the posterior states
        ``start``, selecting actions with ``actor``. ``repeats`` is currently unused.
    """
    dynamics = world_model.dynamics
    # fold the (batch, time) dimensions of the posterior states into a single flat batch
    flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
    start = {k: flatten(v) for k, v in start.items()}
    def step(prev, _):
        state, _, _ = prev
        feat = dynamics.get_feat(state)
        # cut gradients so the actor does not backpropagate into the world model
        inp = feat.detach()
        action = actor(inp).sample()
        # one imagination step of the latent dynamics
        succ = dynamics.img_step(state, action, sample=cfg.imag_sample)
        return succ, feat, action
    succ, feats, actions = static_scan(step, [torch.arange(horizon)], (start, None, None))
    # prepend the start state and drop the last successor so that states[t] aligns with feats[t]
    states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
    return feats, states, actions
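

# Illustrative sketch (not part of the original module): `imagine` only needs a world model
# exposing `dynamics.get_feat` / `dynamics.img_step` and an actor returning a torch
# distribution. The stubs below are hypothetical stand-ins used purely to show the expected
# interfaces and output shapes.
def _example_imagine():
    import types

    class _StubDynamics:

        def get_feat(self, state):
            # use the deterministic latent directly as the feature
            return state["deter"]

        def img_step(self, state, action, sample=True):
            # toy latent transition: shift the latent by the mean action
            return {"deter": state["deter"] + action.mean(-1, keepdim=True)}

    world_model = types.SimpleNamespace(dynamics=_StubDynamics())
    actor = lambda feat: torch.distributions.Independent(
        torch.distributions.Normal(torch.zeros(*feat.shape[:-1], 3), 1.0), 1
    )
    cfg = types.SimpleNamespace(imag_sample=True)
    start = {"deter": torch.zeros(4, 2, 8)}  # [batch, batch_length, deter_dim]
    feats, states, actions = imagine(cfg, world_model, start, actor, horizon=5)
    # feats: [horizon, 4*2, 8], states["deter"]: [horizon, 4*2, 8], actions: [horizon, 4*2, 3]
    return feats, states, actions
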
def compute_target(cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent):
if "discount" in world_model.heads:
inp = world_model.dynamics.get_feat(imag_state)
discount = cfg.discount * world_model.heads["discount"](inp).mean
# TODO whether to detach
discount = discount.detach()
else:
discount = cfg.discount * torch.ones_like(reward)
    value = critic(imag_feat).mode()
    # value:    (imag_horizon, batch_size, 1)
    # action:   (imag_horizon, batch_size, action_dim)
    # discount: (imag_horizon, batch_size, 1)
    target = generalized_lambda_returns(value, reward[:-1], discount[:-1], cfg.lambda_)
    # weight each imagined step by the cumulative product of predicted discounts (continuation)
    weights = torch.cumprod(torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0).detach()
return target, weights, value[:-1]
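

# For reference, a minimal sketch of the lambda-return recursion that
# `generalized_lambda_returns` is used to compute above, assuming the standard Dreamer
# formulation (the ding implementation may differ in details such as masking and
# bootstrapping conventions):
#     R_t = r_t + gamma_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}),  with R_T = V_T
def _lambda_return_sketch(value, reward, discount, lambda_):
    # value: [H, B]; reward/discount: [H - 1, B]; returns lambda-targets of shape [H - 1, B]
    returns = []
    last = value[-1]  # bootstrap from the final value estimate
    for t in reversed(range(reward.shape[0])):
        last = reward[t] + discount[t] * ((1 - lambda_) * value[t + 1] + lambda_ * last)
        returns.append(last)
    return torch.stack(returns[::-1], dim=0)
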
def compute_actor_loss(
cfg,
actor,
reward_ema,
imag_feat,
imag_state,
imag_action,
target,
actor_ent,
state_ent,
weights,
base,
):
metrics = {}
inp = imag_feat.detach()
policy = actor(inp)
actor_ent = policy.entropy()
# Q-val for actor is not transformed using symlog
    if cfg.reward_EMA:
        # normalize returns by an EMA of their 5%-95% quantile range
        offset, scale = reward_ema(target)
        normed_target = (target - offset) / scale
        normed_base = (base - offset) / scale
        adv = normed_target - normed_base
        metrics.update(tensorstats(normed_target, "normed_target"))
        values = reward_ema.values
        metrics["EMA_005"] = values[0].detach().cpu().numpy().item()
        metrics["EMA_095"] = values[1].detach().cpu().numpy().item()
    else:
        # without EMA normalization, fall back to the raw advantage
        adv = target - base
    actor_target = adv
if cfg.actor_entropy > 0:
actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None]
actor_target += actor_entropy
metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().numpy().item()
if cfg.actor_state_entropy > 0:
state_entropy = cfg.actor_state_entropy * state_ent[:-1]
actor_target += state_entropy
metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().numpy().item()
actor_loss = -torch.mean(weights[:-1] * actor_target)
return actor_loss, metrics
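

# Illustrative sketch (not part of the original module): calling the actor loss with toy
# tensors. The Gaussian actor, the config fields and all shapes below are hypothetical;
# only the reward_EMA path is exercised here.
def _example_actor_loss():
    import types
    H, B, F, A = 6, 4, 8, 2
    cfg = types.SimpleNamespace(reward_EMA=True, actor_entropy=3e-4, actor_state_entropy=0.0)
    actor = lambda feat: torch.distributions.Independent(
        torch.distributions.Normal(torch.zeros(*feat.shape[:-1], A), 1.0), 1
    )
    imag_feat = torch.randn(H, B, F)   # imagined features
    target = torch.randn(H - 1, B, 1)  # lambda-returns from compute_target
    base = torch.randn(H - 1, B, 1)    # value baseline (value[:-1])
    weights = torch.ones(H, B, 1)      # cumulative discount weights
    loss, metrics = compute_actor_loss(
        cfg, actor, RewardEMA(device="cpu"), imag_feat, None, None, target, None, None, weights, base
    )
    return loss, metrics
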
class RewardEMA(object):
"""running mean and std"""
def __init__(self, device, alpha=1e-2):
self.device = device
self.values = torch.zeros((2, )).to(device)
self.alpha = alpha
self.range = torch.tensor([0.05, 0.95]).to(device)
def __call__(self, x):
flat_x = torch.flatten(x.detach())
x_quantile = torch.quantile(input=flat_x, q=self.range)
self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
scale = torch.clip(self.values[1] - self.values[0], min=1.0)
offset = self.values[0]
return offset.detach(), scale.detach()
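

# Illustrative usage sketch (not part of the original module): RewardEMA tracks smoothed
# 5%/95% return quantiles and yields (offset, scale) statistics for normalizing
# lambda-returns; the scale is clipped to >= 1 so small ranges are not amplified.
def _example_reward_ema():
    ema = RewardEMA(device="cpu")
    returns = torch.randn(15, 64, 1)  # [imag_horizon, batch, 1] imagined returns
    offset, scale = ema(returns)      # EMA of the 5% quantile and of the clipped 95%-5% spread
    normed = (returns - offset) / scale
    return normed
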
def tensorstats(tensor, prefix=None):
metrics = {
'mean': torch.mean(tensor).detach().cpu().numpy(),
'std': torch.std(tensor).detach().cpu().numpy(),
'min': torch.min(tensor).detach().cpu().numpy(),
'max': torch.max(tensor).detach().cpu().numpy(),
}
if prefix:
metrics = {f'{prefix}_{k}': v.item() for k, v in metrics.items()}
return metrics
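

# Illustrative usage sketch (not part of the original module): with a prefix, tensorstats
# returns plain floats keyed as '<prefix>_mean', '<prefix>_std', '<prefix>_min', '<prefix>_max'.
def _example_tensorstats():
    stats = tensorstats(torch.randn(15, 64), prefix="normed_target")
    # e.g. {'normed_target_mean': ..., 'normed_target_std': ..., 'normed_target_min': ..., 'normed_target_max': ...}
    return stats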