|
from typing import Callable, Tuple, Union

import torch
from torch import Tensor

from ding.torch_utils import fold_batch, unfold_batch
from ding.rl_utils import generalized_lambda_returns
from ding.torch_utils.network.dreamer import static_scan
|
|
def q_evaluation(obss: Tensor, actions: Tensor,
                 q_critic_fn: Callable[[Tensor, Tensor], Tensor]) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """
    Overview:
        Evaluate (observation, action) pairs along the trajectory.
    Arguments:
        - obss (:obj:`torch.Tensor`): the observations along the trajectory
        - actions (:obj:`torch.Tensor`): the actions along the trajectory
        - q_critic_fn (:obj:`Callable`): the unified API :math:`Q(S_t, A_t)`
    Returns:
        - q_value (:obj:`torch.Tensor`): the action-value function evaluated along the trajectory. If \
            ``q_critic_fn`` is a twin critic that returns two values, a pair of such tensors is returned.
    Shapes:
        :math:`N`: time step
        :math:`B`: batch size
        :math:`O`: observation dimension
        :math:`A`: action dimension

        - obss: [N, B, O]
        - actions: [N, B, A]
        - q_value: [N, B]
    """
|
    # Fold the time dimension into the batch dimension so the critic is called
    # only once on [N * B, *] inputs.
    obss, dim = fold_batch(obss, 1)
    actions, _ = fold_batch(actions, 1)
    q_values = q_critic_fn(obss, actions)
    # Twin critics return a pair of Q-values; unfold each back to [N, B].
    if isinstance(q_values, list):
        return [unfold_batch(q_values[0], dim), unfold_batch(q_values[1], dim)]
    return unfold_batch(q_values, dim)
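
# Illustrative usage (hypothetical critic function; shapes follow the docstring above):
#   >>> obss = torch.randn(10, 4, 16)      # N=10 steps, B=4 batch, O=16
#   >>> actions = torch.randn(10, 4, 6)    # A=6
#   >>> q = q_evaluation(obss, actions, lambda o, a: torch.cat([o, a], -1).sum(-1))
#   >>> q.shape                            # torch.Size([10, 4])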
|
|
|
|
|
def imagine(cfg, world_model, start, actor, horizon, repeats=None):
    # Roll out the learned latent dynamics for ``horizon`` steps, sampling actions
    # from the actor on detached features. ``repeats`` is currently unused.
    dynamics = world_model.dynamics

    def flatten(x):
        # Merge the leading (batch, time) dims of every start tensor.
        return x.reshape([-1] + list(x.shape[2:]))

    start = {k: flatten(v) for k, v in start.items()}

    def step(prev, _):
        state, _, _ = prev
        feat = dynamics.get_feat(state)
        inp = feat.detach()
        action = actor(inp).sample()
        succ = dynamics.img_step(state, action, sample=cfg.imag_sample)
        return succ, feat, action

    succ, feats, actions = static_scan(step, [torch.arange(horizon)], (start, None, None))
    # Shift the successor states by one step so that states[t] is the state in
    # which actions[t] was taken (prepend the start state, drop the last successor).
    states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}

    return feats, states, actions
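
# Shape sketch for ``imagine`` (assumed from the code above, not from upstream docs):
# with the start latents flattened to B' = batch * time entries, feature dim F and
# action dim A, a rollout of ``horizon`` steps yields
#   feats:   [horizon, B', F]   features of the state each action was taken in
#   states:  dict of [horizon, B', ...] latents aligned with ``feats``
#   actions: [horizon, B', A]   actions sampled from ``actor`` on detached features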
|
|
|
|
|
def compute_target(cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent):
    # Per-step discount: either predicted by the world model's "discount" head
    # or a constant ``cfg.discount``.
    if "discount" in world_model.heads:
        inp = world_model.dynamics.get_feat(imag_state)
        discount = cfg.discount * world_model.heads["discount"](inp).mean
        discount = discount.detach()
    else:
        discount = cfg.discount * torch.ones_like(reward)

    value = critic(imag_feat).mode()
    # Bootstrapped lambda-return targets along the imagined trajectory.
    target = generalized_lambda_returns(value, reward[:-1], discount[:-1], cfg.lambda_)
    # Cumulative product of discounts, used to down-weight later imagined steps.
    weights = torch.cumprod(torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0).detach()
    return target, weights, value[:-1]
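
# For reference, ``generalized_lambda_returns`` computes the usual TD(lambda)
# recursion over the imagined trajectory (written here for clarity, not copied
# from that function's source):
#   G_t = r_t + gamma_t * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}),
# bootstrapped from the critic value at the final imagined step.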
|
|
|
|
|
def compute_actor_loss(
    cfg,
    actor,
    reward_ema,
    imag_feat,
    imag_state,
    imag_action,
    target,
    actor_ent,
    state_ent,
    weights,
    base,
):
    metrics = {}
    # The actor only receives detached features, so its gradient does not flow
    # back into the world model.
    inp = imag_feat.detach()
    policy = actor(inp)
    actor_ent = policy.entropy()

    # Normalize returns with an EMA of their 5%/95% quantiles and form the advantage
    # as normalized target minus normalized baseline.
    # NOTE: the code below assumes ``cfg.reward_EMA`` is enabled; otherwise ``adv`` is undefined.
    if cfg.reward_EMA:
        offset, scale = reward_ema(target)
        normed_target = (target - offset) / scale
        normed_base = (base - offset) / scale
        adv = normed_target - normed_base
        metrics.update(tensorstats(normed_target, "normed_target"))
        values = reward_ema.values
        metrics["EMA_005"] = values[0].detach().cpu().numpy().item()
        metrics["EMA_095"] = values[1].detach().cpu().numpy().item()

    actor_target = adv
    # Optional entropy bonuses on the policy and on the latent state.
    if cfg.actor_entropy > 0:
        actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None]
        actor_target += actor_entropy
        metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().numpy().item()
    if cfg.actor_state_entropy > 0:
        state_entropy = cfg.actor_state_entropy * state_ent[:-1]
        actor_target += state_entropy
        metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().numpy().item()
    # Maximize the weighted actor target, i.e. minimize its negative mean.
    actor_loss = -torch.mean(weights[:-1] * actor_target)
    return actor_loss, metrics
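
# For reference, the per-step objective maximized above is (illustrative notation):
#   J_t = w_t * ( (G_t - V_t) / scale + eta_pi * H[pi(.|s_t)] + eta_s * H[s_t] )
# where w_t is the cumulative-discount weight from ``compute_target``, G_t the
# lambda-return target, V_t the critic baseline, and the entropy terms are only
# added when their coefficients are positive; the returned loss is -mean(J_t).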
|
|
|
|
|
class RewardEMA(object):
    """Running EMA of the 5th and 95th return quantiles, used to normalize returns."""

    def __init__(self, device, alpha=1e-2):
        self.device = device
        self.values = torch.zeros((2, )).to(device)
        self.alpha = alpha
        self.range = torch.tensor([0.05, 0.95]).to(device)

    def __call__(self, x):
        flat_x = torch.flatten(x.detach())
        x_quantile = torch.quantile(input=flat_x, q=self.range)
        # Exponential moving average of the two quantiles.
        self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
        # Clip the scale to at least 1 so that already-small returns are not amplified.
        scale = torch.clip(self.values[1] - self.values[0], min=1.0)
        offset = self.values[0]
        return offset.detach(), scale.detach()
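
# Illustrative usage (hypothetical ``target`` tensor): the returned offset/scale map
# the bulk of the return distribution into roughly unit range:
#   >>> reward_ema = RewardEMA(device=target.device)
#   >>> offset, scale = reward_ema(target)
#   >>> normed_target = (target - offset) / scale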
|
|
|
|
|
def tensorstats(tensor, prefix=None):
    # Summary statistics (mean/std/min/max) of a tensor for logging; values are
    # converted to python scalars when a prefix is given.
    metrics = {
        'mean': torch.mean(tensor).detach().cpu().numpy(),
        'std': torch.std(tensor).detach().cpu().numpy(),
        'min': torch.min(tensor).detach().cpu().numpy(),
        'max': torch.max(tensor).detach().cpu().numpy(),
    }
    if prefix:
        metrics = {f'{prefix}_{k}': v.item() for k, v in metrics.items()}
    return metrics
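
# e.g. (hypothetical tensor) tensorstats(advantage, "adv") ->
#   {'adv_mean': ..., 'adv_std': ..., 'adv_min': ..., 'adv_max': ...}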
|
|