from typing import Callable, Tuple, Union
import torch
from torch import Tensor
from ding.torch_utils import fold_batch, unfold_batch
from ding.rl_utils import generalized_lambda_returns
from import static_scan
def q_evaluation(
        obss: Tensor, actions: Tensor, q_critic_fn: Callable[[Tensor, Tensor], Tensor]
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """
    Overview:
        Evaluate (observation, action) pairs along the trajectory.

    Arguments:
        - obss (:obj:`torch.Tensor`): the observations along the trajectory
        - actions (:obj:`torch.Tensor`): the actions along the trajectory
        - q_critic_fn (:obj:`Callable`): the unified API :math:`Q(S_t, A_t)`; it may return a single \
            tensor or, for a twin critic, a list/tuple with one tensor per critic head

    Returns:
        - q_value (:obj:`torch.Tensor` or :obj:`list`): the action-value function evaluated along the \
            trajectory; a list with one entry per head when the critic returns a list/tuple

    Shapes:
        :math:`N`: time step
        :math:`B`: batch size
        :math:`O`: observation dimension
        :math:`A`: action dimension

        - obss:        [N, B, O]
        - actions:     [N, B, A]
        - q_value:     [N, B]
    """
    # Fold the leading (time, batch) dims so the critic sees a flat batch; ``dim`` is what
    # ``unfold_batch`` needs to restore them afterwards.
    obss, dim = fold_batch(obss, 1)
    actions, _ = fold_batch(actions, 1)
    q_values = q_critic_fn(obss, actions)
    # Twin critic: unfold every head. Tuples are accepted as well as lists.
    if isinstance(q_values, (list, tuple)):
        return [unfold_batch(q, dim) for q in q_values]
    return unfold_batch(q_values, dim)
def imagine(cfg, world_model, start, actor, horizon, repeats=None):
    """
    Overview:
        Roll the world model's dynamics forward for ``horizon`` steps, choosing actions with ``actor``.

    Arguments:
        - cfg: config object; ``cfg.imag_sample`` is forwarded to ``dynamics.img_step`` as ``sample``.
        - world_model: object exposing ``.dynamics`` with ``get_feat(state)`` and ``img_step(state, action, sample=...)``.
        - start (:obj:`dict`): initial state tensors; the first two dims of each are merged into one.
        - actor: callable mapping features to an object with a ``sample()`` method.
        - horizon (:obj:`int`): number of imagination steps.
        - repeats: unused; kept for interface compatibility.

    Returns:
        - feats: features of the state each step started from, as collected by ``static_scan``.
        - states (:obj:`dict`): per key, the ``start`` state followed by all but the last successor \
            state, concatenated along dim 0.
        - actions: the sampled actions, as collected by ``static_scan``.
    """
    dynamics = world_model.dynamics

    def flatten(x):
        # Merge the first two dims: (d0, d1, ...) -> (d0 * d1, ...).
        return x.reshape([-1] + list(x.shape[2:]))

    start = {k: flatten(v) for k, v in start.items()}

    def step(prev, _):
        state, _, _ = prev
        feat = dynamics.get_feat(state)
        # Detach so the actor's input carries no gradient back into the dynamics.
        action = actor(feat.detach()).sample()
        succ = dynamics.img_step(state, action, sample=cfg.imag_sample)
        return succ, feat, action

    succ, feats, actions = static_scan(step, [torch.arange(horizon)], (start, None, None))
    # Align states with feats/actions: the state at step t is the one img_step started from,
    # i.e. ``start`` followed by every successor except the last.
    states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
    return feats, states, actions
def compute_target(cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent):
    """
    Overview:
        Compute lambda-return targets and discount weights for an imagined trajectory.

    Arguments:
        - cfg: config object; reads ``cfg.discount`` and ``cfg.lambda_``.
        - world_model: object with a ``heads`` mapping and ``dynamics.get_feat``; if it has a \
            ``"discount"`` head, that head's ``.mean`` provides the per-step discount.
        - critic: callable mapping ``imag_feat`` to an object with a ``mode()`` method.
        - imag_feat: features of the imagined states (time is dim 0).
        - imag_state: imagined states, only used when a discount head exists.
        - reward: imagined rewards (time is dim 0).
        - actor_ent, state_ent: unused; kept for interface compatibility.

    Returns:
        - target: ``generalized_lambda_returns(value, reward[:-1], discount[:-1], cfg.lambda_)``
        - weights: detached cumulative product of discounts, shifted so that ``weights[0] == 1``
        - value: critic values with the last step dropped (``value[:-1]``)
    """
    if "discount" in world_model.heads:
        inp = world_model.dynamics.get_feat(imag_state)
        discount = cfg.discount * world_model.heads["discount"](inp).mean
        # TODO whether to detach
        discount = discount.detach()
    else:
        discount = cfg.discount * torch.ones_like(reward)

    value = critic(imag_feat).mode()
    # Example shapes seen in practice:
    # value(imag_horizon, 16*64, 1)
    # action(imag_horizon, 16*64, ch)
    # discount(imag_horizon, 16*64, 1)
    target = generalized_lambda_returns(value, reward[:-1], discount[:-1], cfg.lambda_)
    # weights[t] = prod_{i < t} discount[i]; the leading ones make the first step weight 1.
    weights = torch.cumprod(torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0).detach()
    return target, weights, value[:-1]
def compute_actor_loss(
    cfg,
    actor,
    reward_ema,
    imag_feat,
    imag_state,
    imag_action,
    target,
    actor_ent,
    state_ent,
    weights,
    base,
):
    """
    Overview:
        Compute the actor loss ``-mean(weights[:-1] * (advantage + entropy bonuses))`` and logging metrics.

    Arguments:
        - cfg: config object; reads ``reward_EMA``, ``actor_entropy`` and ``actor_state_entropy``.
        - actor: callable mapping features to a policy object with an ``entropy()`` method.
        - reward_ema: callable returning ``(offset, scale)`` and exposing ``.values``; only used \
            when ``cfg.reward_EMA`` is true.
        - imag_feat: features of the imagined states; detached before being fed to ``actor``.
        - imag_state, imag_action: unused; kept for interface compatibility.
        - target: return targets (presumably the lambda-returns from ``compute_target``).
        - actor_ent: ignored; the entropy is recomputed from the current policy.
        - state_ent: state-entropy bonus, only read when ``cfg.actor_state_entropy > 0``.
        - weights: per-step weights; the last step is dropped (``weights[:-1]``).
        - base: baseline subtracted from ``target`` (presumably the critic values).

    Returns:
        - actor_loss: scalar loss tensor.
        - metrics (:obj:`dict`): Python floats for logging.
    """
    metrics = {}
    policy = actor(imag_feat.detach())
    actor_ent = policy.entropy()
    # Q-val for actor is not transformed using symlog
    if cfg.reward_EMA:
        # Normalise returns with the EMA of their 5%/95% quantiles.
        offset, scale = reward_ema(target)
        normed_target = (target - offset) / scale
        normed_base = (base - offset) / scale
        adv = normed_target - normed_base
        metrics.update(tensorstats(normed_target, "normed_target"))
        values = reward_ema.values
        metrics["EMA_005"] = values[0].detach().cpu().item()
        metrics["EMA_095"] = values[1].detach().cpu().item()
    else:
        # Without return normalisation the advantage is the raw target minus the baseline.
        adv = target - base

    # Out-of-place additions so the entropy bonuses never mutate ``adv`` in place.
    actor_target = adv
    if cfg.actor_entropy > 0:
        # actor_ent is indexed as 2-D here (time, batch) -- assumed; a trailing axis is added to match target.
        actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None]
        actor_target = actor_target + actor_entropy
        metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().item()

    if cfg.actor_state_entropy > 0:
        state_entropy = cfg.actor_state_entropy * state_ent[:-1]
        actor_target = actor_target + state_entropy
        metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().item()

    actor_loss = -torch.mean(weights[:-1] * actor_target)
    return actor_loss, metrics
class RewardEMA:
    """
    Overview:
        Exponential moving average of the 5th and 95th percentiles of the values it is called with.
        Each call updates the EMA and returns an ``(offset, scale)`` pair for normalising, where
        ``offset`` is the EMA of the 5% quantile and ``scale`` is the EMA inter-quantile range,
        clipped to at least 1.
    """

    def __init__(self, device, alpha=1e-2):
        self.device = device
        # EMA of the [5%, 95%] quantiles, initialised to zero.
        self.values = torch.zeros((2, ), device=device)
        # Weight given to the newest quantile estimate.
        self.alpha = alpha
        # Quantile levels being tracked.
        self.range = torch.tensor([0.05, 0.95], device=device)

    def __call__(self, x):
        flat_x = torch.flatten(x.detach())
        # torch.quantile requires q to have the same dtype as the input (e.g. float64 / float16).
        q = self.range.to(dtype=flat_x.dtype)
        x_quantile = torch.quantile(input=flat_x, q=q)
        self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
        # Lower-bound the scale at 1 so small spreads are not amplified.
        scale = torch.clip(self.values[1] - self.values[0], min=1.0)
        offset = self.values[0]
        return offset.detach(), scale.detach()
def tensorstats(tensor, prefix=None):
    """
    Overview:
        Summary statistics (mean, std, min, max) of a tensor, for logging.

    Arguments:
        - tensor (:obj:`torch.Tensor`): floating-point tensor to summarise.
        - prefix (:obj:`str`, optional): if truthy, every key becomes ``f'{prefix}_{key}'`` and every \
            value is converted to a Python scalar via ``.item()``.

    Returns:
        - metrics (:obj:`dict`): keys ``mean``/``std``/``min``/``max`` (prefixed when ``prefix`` is \
            given). Without a prefix the values are 0-d NumPy arrays; with a prefix they are Python scalars.
    """
    metrics = {
        'mean': torch.mean(tensor).detach().cpu().numpy(),
        'std': torch.std(tensor).detach().cpu().numpy(),
        'min': torch.min(tensor).detach().cpu().numpy(),
        'max': torch.max(tensor).detach().cpu().numpy(),
    }
    if prefix:
        metrics = {f'{prefix}_{k}': v.item() for k, v in metrics.items()}
    return metrics