import copy |
import numpy as np |
from collections import namedtuple |
from typing import Union, Optional, Callable |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from ding.hpc_rl import hpc_wrapper |
from ding.rl_utils.value_rescale import value_transform, value_inv_transform |
from ding.torch_utils import to_tensor |
# Input record for ``q_1step_td_error``; that function unpacks the fields positionally in this order.
q_1step_td_data = namedtuple(
    'q_1step_td_data',
    ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight'],
)
def discount_cumsum(x, gamma: float = 1.0) -> np.ndarray:
    """
    Overview:
        Compute the reverse (right-to-left) discounted cumulative sum of ``x`` along its first axis, \
        i.e. ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``.
    Arguments:
        - x: Array-like object supporting ``shape``, integer indexing and ``np.zeros_like``.
        - gamma (:obj:`float`): Discount factor; the function asserts that it is (numerically) equal to 1.0.
    Returns:
        - out (:obj:`np.ndarray`): Array created with ``np.zeros_like(x)`` holding the cumulative sums.
    """
    assert abs(gamma - 1.) < 1e-5, "gamma equals to 1.0 in original decision transformer paper"
    out = np.zeros_like(x)
    last = x.shape[0] - 1
    out[last] = x[last] if last >= 0 else x[-1]
    # Walk backwards from the second-to-last entry, reusing the already accumulated tail.
    for idx in range(last - 1, -1, -1):
        out[idx] = x[idx] + gamma * out[idx + 1]
    return out
def q_1step_td_error(
        data: namedtuple,
        gamma: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')
) -> torch.Tensor:
    """
    Overview:
        1 step td_error for q-learning. Only 1-dim ``act`` and ``reward`` are accepted (enforced by asserts).
    Arguments:
        - data (:obj:`q_1step_td_data`): The input data, q_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, expected to return per-sample loss
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean 1step td error, 0-dim tensor
    Shapes:
        - data (:obj:`q_1step_td_data`): the q_1step_td_data containing \
            ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = q_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     next_q=torch.randn(3, action_dim),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     next_act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)).bool(),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss = q_1step_td_error(data, 0.99)
    """
    q, next_q, act, next_act, reward, done, weight = data
    assert len(act.shape) == 1, act.shape
    assert len(reward.shape) == 1, reward.shape
    # ``1 - done`` is not defined for bool tensors in PyTorch, so cast the documented BoolTensor first.
    if torch.is_tensor(done) and done.dtype == torch.bool:
        done = done.float()
    batch_range = torch.arange(act.shape[0])
    if weight is None:
        weight = torch.ones_like(reward)
    q_s_a = q[batch_range, act]
    target_q_s_a = next_q[batch_range, next_act]
    # Bootstrapped value is dropped for terminal transitions.
    target_q_s_a = gamma * (1 - done) * target_q_s_a + reward
    return (criterion(q_s_a, target_q_s_a.detach()) * weight).mean()
# Input record for ``m_q_1step_td_error``; that function unpacks the fields positionally in this order.
m_q_1step_td_data = namedtuple(
    'm_q_1step_td_data',
    ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight'],
)
def m_q_1step_td_error(
        data: namedtuple,
        gamma: float,
        tau: float,
        alpha: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')
) -> tuple:
    """
    Overview:
        Munchausen td_error for DQN algorithm, support 1 step td error.
    Arguments:
        - data (:obj:`m_q_1step_td_data`): The input data, m_q_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - tau (:obj:`float`): Entropy factor for Munchausen DQN
        - alpha (:obj:`float`): Scaling factor for the Munchausen term
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, expected to return per-sample loss
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean 1step td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Unweighted td error of each sample, :math:`(B, )`
        - action_gap (:obj:`torch.Tensor`): Mean gap between the two largest ``target_q`` values, 0-dim tensor
        - clipfrac (:obj:`torch.Tensor`): Float mask :math:`(B, 1)` marking samples whose Munchausen add-on \
            fell outside ``[-1, 1]`` and was clipped
    Shapes:
        - data (:obj:`m_q_1step_td_data`): the m_q_1step_td_data containing \
            ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - target_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = m_q_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     target_q=torch.randn(3, action_dim),
        >>>     next_q=torch.randn(3, action_dim),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss, td_error_per_sample, action_gap, clipfrac = m_q_1step_td_error(data, 0.99, 0.01, 0.01)
    """
    q, target_q, next_q, act, reward, done, weight = data
    lower_bound = -1
    assert len(act.shape) == 1, act.shape
    assert len(reward.shape) == 1, reward.shape
    # ``1 - done`` is not defined for bool tensors in PyTorch, so cast the documented BoolTensor first.
    if torch.is_tensor(done) and done.dtype == torch.bool:
        done = done.float()
    batch_range = torch.arange(act.shape[0])
    if weight is None:
        weight = torch.ones_like(reward)
    q_s_a = q[batch_range, act]

    # Munchausen add-on: tau * log pi(a|s) under the target network, computed with the
    # max-subtraction trick for numerical stability, then clipped to [lower_bound, 1].
    target_v_s = target_q.max(1)[0].unsqueeze(-1)
    logsum = torch.logsumexp((target_q - target_v_s) / tau, 1).unsqueeze(-1)
    tau_log_pi = target_q - target_v_s - tau * logsum
    munchausen_addon = tau_log_pi.gather(1, act.unsqueeze(-1))
    munchausen_term = alpha * torch.clamp(munchausen_addon, min=lower_bound, max=1)

    # Soft (entropy-regularised) value of the next state.
    target_v_s_next = next_q.max(1)[0].unsqueeze(-1)
    logsum_next = torch.logsumexp((next_q - target_v_s_next) / tau, 1).unsqueeze(-1)
    tau_log_pi_next = next_q - target_v_s_next - tau * logsum_next
    # Explicit ``dim`` avoids the deprecated implicit-dimension choice of softmax.
    pi_target = F.softmax((next_q - target_v_s_next) / tau, dim=1)
    target_q_s_a = (gamma * (pi_target * (next_q - tau_log_pi_next) * (1 - done.unsqueeze(-1))).sum(1)).unsqueeze(-1)
    target_q_s_a = reward.unsqueeze(-1) + munchausen_term + target_q_s_a

    td_error_per_sample = criterion(q_s_a.unsqueeze(-1), target_q_s_a.detach()).squeeze(-1)

    # Diagnostics only, no gradient needed.
    with torch.no_grad():
        top2_q_s = target_q.topk(2, dim=1, largest=True, sorted=True)[0]
        action_gap = (top2_q_s[:, 0] - top2_q_s[:, 1]).mean()
        clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lower_bound)
        clipfrac = clipped.float()

    return (td_error_per_sample * weight).mean(), td_error_per_sample, action_gap, clipfrac
# Input record for ``q_v_1step_td_error``; that function unpacks the fields positionally in this order.
q_v_1step_td_data = namedtuple(
    'q_v_1step_td_data',
    ['q', 'v', 'act', 'reward', 'done', 'weight'],
)
def q_v_1step_td_error(
        data: namedtuple, gamma: float, criterion: torch.nn.modules = nn.MSELoss(reduction='none')
) -> tuple:
    """
    Overview:
        td_error between q and v value for SAC algorithm, support 1 step td error. A 1-dim ``act`` is \
        handled as the single agent case, otherwise ``act`` is treated as :math:`(B, A)` (multi agent).
    Arguments:
        - data (:obj:`q_v_1step_td_data`): The input data, q_v_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, expected to return per-sample loss
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean 1step td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Unweighted td error, same shape as ``act``
    Shapes:
        - data (:obj:`q_v_1step_td_data`): the q_v_1step_td_data containing \
            ['q', 'v', 'act', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> action_dim = 4
        >>> data = q_v_1step_td_data(
        >>>     q=torch.randn(3, action_dim),
        >>>     v=torch.randn(3),
        >>>     act=torch.randint(0, action_dim, (3,)),
        >>>     reward=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>>     weight=torch.ones(3),
        >>> )
        >>> loss, td_error_per_sample = q_v_1step_td_error(data, 0.99)
    """
    q, v, act, reward, done, weight = data
    assert len(reward.shape) == 1, reward.shape
    # ``1 - done`` is not defined for bool tensors in PyTorch, so cast the documented BoolTensor first.
    if torch.is_tensor(done) and done.dtype == torch.bool:
        done = done.float()
    if len(act.shape) == 1:
        # Single agent case.
        batch_range = torch.arange(act.shape[0])
        if weight is None:
            weight = torch.ones_like(reward)
        q_s_a = q[batch_range, act]
        target_q_s_a = gamma * (1 - done) * v + reward
    else:
        # Multi agent case: flatten (batch, agent) so one advanced-index picks every chosen action value.
        batch_size, agent_num = act.shape[0], act.shape[1]
        flat_range = torch.arange(batch_size * agent_num)
        if weight is None:
            weight = torch.ones_like(act)
        flat_q = q.reshape(batch_size * agent_num, -1)
        flat_act = act.reshape(batch_size * agent_num)
        q_s_a = flat_q[flat_range, flat_act].reshape(batch_size, agent_num)
        # reward/done are shared by all agents of a sample, hence the broadcast over the agent dim.
        target_q_s_a = gamma * (1 - done).unsqueeze(1) * v + reward.unsqueeze(1)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def view_similar(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Return a view of ``x`` with trailing singleton dimensions appended until it has as many \
        dimensions as ``target``. If ``x`` already has at least as many dimensions, its shape is kept.
    Arguments:
        - x (:obj:`torch.Tensor`): Tensor to be viewed
        - target (:obj:`torch.Tensor`): Tensor whose number of dimensions should be matched
    Returns:
        - x (:obj:`torch.Tensor`): View of ``x`` with shape ``x.shape + (1, ) * k``
    """
    missing_dims = len(target.shape) - len(x.shape)
    padded_shape = list(x.shape)
    # A non-positive count multiplies to an empty list, leaving the shape untouched.
    padded_shape.extend([1] * missing_dims)
    return x.view(*padded_shape)
# Input record for ``nstep_return``; that function unpacks the fields positionally in this order.
nstep_return_data = namedtuple(
    'nstep_return_data',
    ['reward', 'next_value', 'done'],
)
def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_gamma: Optional[torch.Tensor] = None):
    '''
    Overview:
        Calculate nstep return for DQN algorithm, support single agent case and multi agent case.
    Arguments:
        - data (:obj:`nstep_return_data`): The input data, nstep_return_data to calculate loss
        - gamma (:obj:`float` or :obj:`list`): Discount factor. A float is shared by all samples; a list \
            is stacked with ``torch.stack`` (so it must contain tensors) into a per-sample discount.
        - nstep (:obj:`int`): nstep num, must equal ``reward.shape[0]``
        - value_gamma (:obj:`torch.Tensor`): Discount factor applied to ``next_value`` instead of \
            ``gamma ** nstep``; only used when ``gamma`` is a float
    Returns:
        - return (:obj:`torch.Tensor`): nstep return
    Raises:
        - TypeError: If ``gamma`` is neither a float nor a list
    Shapes:
        - data (:obj:`nstep_return_data`): the nstep_return_data containing \
            ['reward', 'next_value', 'done']
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - next_value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
    Examples:
        >>> data = nstep_return_data(
        >>>     reward=torch.randn(3, 3),
        >>>     next_value=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>> )
        >>> loss = nstep_return(data, 0.99, 3)
    '''
    reward, next_value, done = data
    assert reward.shape[0] == nstep
    device = reward.device
    # ``1 - done`` is not defined for bool tensors in PyTorch, so cast the documented BoolTensor first.
    if torch.is_tensor(done) and done.dtype == torch.bool:
        done = done.float()

    if isinstance(gamma, float):
        # reward_factor[i] = gamma ** i, built by repeated multiplication.
        reward_factor = torch.ones(nstep).to(device)
        for i in range(1, nstep):
            reward_factor[i] = gamma * reward_factor[i - 1]
        reward_factor = view_similar(reward_factor, reward)
        return_tmp = reward.mul(reward_factor).sum(0)
        if value_gamma is None:
            return_ = return_tmp + (gamma ** nstep) * next_value * (1 - done)
        else:
            return_ = return_tmp + value_gamma * next_value * (1 - done)
    elif isinstance(gamma, list):
        # Per-sample discount: stack once instead of on every loop iteration.
        step_gamma = torch.stack(gamma, dim=0).to(device)
        reward_factor = torch.ones([nstep + 1, done.shape[0]]).to(device)
        for i in range(1, nstep + 1):
            reward_factor[i] = step_gamma * reward_factor[i - 1]
        reward_factor = view_similar(reward_factor, reward)
        return_tmp = reward.mul(reward_factor[:nstep]).sum(0)
        # The last row holds gamma ** nstep and discounts the bootstrapped value.
        return_ = return_tmp + reward_factor[nstep] * next_value * (1 - done)
    else:
        raise TypeError("The type of gamma should be float or list")
    return return_
# Input record for ``dist_1step_td_error``; that function unpacks the fields positionally in this order.
_DIST_1STEP_FIELDS = ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
dist_1step_td_data = namedtuple('dist_1step_td_data', _DIST_1STEP_FIELDS)
del _DIST_1STEP_FIELDS
def dist_1step_td_error(
        data: namedtuple,
        gamma: float,
        v_min: float,
        v_max: float,
        n_atom: int,
) -> torch.Tensor:
    """
    Overview:
        1 step td_error for distributional q-learning based algorithm: the next-state distribution is \
        projected onto the fixed support and compared with the current distribution by cross entropy.
    Arguments:
        - data (:obj:`dist_1step_td_data`): The input data, dist_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - v_min (:obj:`float`): The min value of support
        - v_max (:obj:`float`): The max value of support
        - n_atom (:obj:`int`): The num of atom
    Returns:
        - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
    Shapes:
        - data (:obj:`dist_1step_td_data`): the dist_1step_td_data containing \
            ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
        - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
        - next_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
        >>> next_dist = torch.randn(4, 3, 51).abs()
        >>> act = torch.randint(0, 3, (4,))
        >>> next_act = torch.randint(0, 3, (4,))
        >>> reward = torch.randn(4)
        >>> done = torch.randint(0, 2, (4,))
        >>> data = dist_1step_td_data(dist, next_dist, act, next_act, reward, done, None)
        >>> loss = dist_1step_td_error(data, 0.99, -10.0, 10.0, 51)
    """
    dist, next_dist, act, next_act, reward, done, weight = data
    device = reward.device
    assert len(reward.shape) == 1, reward.shape
    # ``1 - done`` is not defined for bool tensors in PyTorch, so cast the documented BoolTensor first.
    if torch.is_tensor(done) and done.dtype == torch.bool:
        done = done.float()
    support = torch.linspace(v_min, v_max, n_atom).to(device)
    delta_z = (v_max - v_min) / (n_atom - 1)

    if len(act.shape) == 1:
        # Single agent case.
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        batch_size = act.shape[0]
        batch_range = torch.arange(batch_size)
        next_dist = next_dist[batch_range, next_act].detach()
    else:
        # Multi agent case: fold the agent dim into the batch dim; reward/done are shared by all agents.
        agent_num = act.shape[1]
        batch_size = act.shape[0] * agent_num
        batch_range = torch.arange(batch_size)
        action_dim = dist.shape[2]
        reward = reward.unsqueeze(-1).repeat(1, agent_num).reshape(batch_size, -1)
        done = done.unsqueeze(-1).repeat(1, agent_num).reshape(batch_size, -1)
        dist = dist.reshape(batch_size, action_dim, -1)
        next_dist = next_dist.reshape(batch_size, action_dim, -1)
        next_act = next_act.reshape(batch_size)
        next_dist = next_dist[batch_range, next_act].detach().reshape(batch_size, -1)
        act = act.reshape(batch_size)
    if weight is None:
        weight = torch.ones_like(reward)
    elif len(weight.shape) == 1:
        # Per-sample weight must broadcast over the atom dim, not be aligned with it.
        weight = weight.unsqueeze(-1)

    # Bellman update of every atom, clipped to the support range.
    target_z = reward + (1 - done) * gamma * support
    target_z = target_z.clamp(min=v_min, max=v_max)
    # Fractional atom index of each updated atom and its two neighbours.
    b = (target_z - v_min) / delta_z
    l = b.floor().long()
    u = b.ceil().long()
    # When b is an exact integer l == u and both interpolation weights would be 0; spread them apart.
    l[(u > 0) * (l == u)] -= 1
    u[(l < (n_atom - 1)) * (l == u)] += 1

    # Distribute probability mass to the neighbouring atoms of the flattened (batch * n_atom) buffer.
    proj_dist = torch.zeros_like(next_dist)
    # Integer arithmetic keeps the row offsets exact (a float linspace cast to long may truncate).
    offset = torch.arange(batch_size, device=device).unsqueeze(1) * n_atom
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))

    log_p = torch.log(dist[batch_range, act])
    loss = -(log_p * proj_dist * weight).sum(-1).mean()
    return loss
# Input record for ``dist_nstep_td_error``; that function unpacks the fields positionally in this order.
# The typename matches the variable name so that repr() and pickling refer to this class and not to
# the unrelated ``dist_1step_td_data`` record, which has different fields.
dist_nstep_td_data = namedtuple(
    'dist_nstep_td_data', ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
)
def shape_fn_dntd(args, kwargs):
    r"""
    Overview:
        Return dntd shape for hpc. The data record is taken from the first positional argument when \
        there is one, otherwise from ``kwargs['data']``.
    Arguments:
        - args: Positional arguments of the wrapped call
        - kwargs: Keyword arguments of the wrapped call
    Returns:
        shape: ``[reward.shape[0]] + list(dist.shape)``, documented upstream as [T, B, N, n_atom]
    """
    td_data = args[0] if len(args) > 0 else kwargs['data']
    return [td_data.reward.shape[0], *td_data.dist.shape]
@hpc_wrapper(
    shape_fn=shape_fn_dntd,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3],
    include_kwargs=['data', 'gamma', 'v_min', 'v_max']
)
def dist_nstep_td_error(
        data: namedtuple,
        gamma: float,
        v_min: float,
        v_max: float,
        n_atom: int,
        nstep: int = 1,
        value_gamma: Optional[torch.Tensor] = None,
) -> tuple:
    """
    Overview:
        Multistep (1 step or n step) td_error for distributional q-learning based algorithm, support \
        single agent case and multi agent case.
    Arguments:
        - data (:obj:`dist_nstep_td_data`): The input data, dist_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - v_min (:obj:`float`): The min value of support
        - v_max (:obj:`float`): The max value of support
        - n_atom (:obj:`int`): The num of atom
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor` or :obj:`float` or None): Discount applied to the support \
            instead of ``gamma ** nstep`` when given
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Unweighted cross entropy of each sample
    Shapes:
        - data (:obj:`dist_nstep_td_data`): the dist_nstep_td_data containing \
            ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
        - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
        - next_n_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or :obj:`float` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
        >>> next_n_dist = torch.randn(4, 3, 51).abs()
        >>> done = torch.randint(0, 2, (4, )).bool()
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> reward = torch.randn(5, 4)
        >>> data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
        >>> loss, _ = dist_nstep_td_error(data, 0.95, -10.0, 10.0, 51, 5)
    """
    dist, next_n_dist, act, next_n_act, reward, done, weight = data
    device = reward.device
    # ``1 - done`` is not defined for bool tensors in PyTorch, so cast the documented BoolTensor first.
    if torch.is_tensor(done) and done.dtype == torch.bool:
        done = done.float()
    # Collapse the (T, B) reward into the discounted n-step reward of shape (B, ).
    reward_factor = torch.ones(nstep).to(device)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)
    support = torch.linspace(v_min, v_max, n_atom).to(device)
    delta_z = (v_max - v_min) / (n_atom - 1)

    if len(act.shape) == 1:
        # Single agent case.
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        batch_size = act.shape[0]
        batch_range = torch.arange(batch_size)
        next_n_dist = next_n_dist[batch_range, next_n_act].detach()
    else:
        # Multi agent case: fold the agent dim into the batch dim; reward/done are shared by all agents.
        agent_num = act.shape[1]
        batch_size = act.shape[0] * agent_num
        batch_range = torch.arange(batch_size)
        action_dim = dist.shape[2]
        reward = reward.unsqueeze(-1).repeat(1, agent_num).reshape(batch_size, -1)
        done = done.unsqueeze(-1).repeat(1, agent_num).reshape(batch_size, -1)
        dist = dist.reshape(batch_size, action_dim, -1)
        next_n_dist = next_n_dist.reshape(batch_size, action_dim, -1)
        next_n_act = next_n_act.reshape(batch_size)
        next_n_dist = next_n_dist[batch_range, next_n_act].detach().reshape(batch_size, -1)
        act = act.reshape(batch_size)
    if weight is None:
        weight = torch.ones_like(reward)
    elif isinstance(weight, float):
        weight = torch.tensor(weight, device=device)

    if value_gamma is None:
        bootstrap_gamma = gamma ** nstep
    else:
        if isinstance(value_gamma, float):
            # Create the tensor on the data device; a 1-dim CPU tensor cannot be mixed with CUDA tensors.
            value_gamma = torch.tensor(value_gamma, device=device)
        bootstrap_gamma = value_gamma.unsqueeze(-1)
    # Bellman update of every atom, clipped to the support range.
    target_z = reward + (1 - done) * bootstrap_gamma * support
    target_z = target_z.clamp(min=v_min, max=v_max)
    # Fractional atom index of each updated atom and its two neighbours.
    b = (target_z - v_min) / delta_z
    l = b.floor().long()
    u = b.ceil().long()
    # When b is an exact integer l == u and both interpolation weights would be 0; spread them apart.
    l[(u > 0) * (l == u)] -= 1
    u[(l < (n_atom - 1)) * (l == u)] += 1

    # Distribute probability mass to the neighbouring atoms of the flattened (batch * n_atom) buffer.
    proj_dist = torch.zeros_like(next_n_dist)
    # Integer arithmetic keeps the row offsets exact (a float linspace cast to long may truncate).
    offset = torch.arange(batch_size, device=device).unsqueeze(1) * n_atom
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_n_dist * (u.float() - b)).view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_n_dist * (b - l.float())).view(-1))

    assert (dist[batch_range, act] > 0.0).all(), ("dist act", dist[batch_range, act], "dist:", dist)
    log_p = torch.log(dist[batch_range, act])

    if len(weight.shape) == 1:
        # Per-sample weight must broadcast over the atom dim.
        weight = weight.unsqueeze(-1)

    td_error_per_sample = -(log_p * proj_dist).sum(-1)

    loss = -(log_p * proj_dist * weight).sum(-1).mean()

    return loss, td_error_per_sample
# Input record for ``v_1step_td_error``; that function unpacks the fields positionally in this order.
v_1step_td_data = namedtuple(
    'v_1step_td_data',
    ['v', 'next_v', 'reward', 'done', 'weight'],
)
def v_1step_td_error(
        data: namedtuple,
        gamma: float,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')
) -> tuple:
    '''
    Overview:
        1 step td_error for state-value based algorithm. When ``v`` has more dimensions than ``reward`` \
        (e.g. an extra agent dim), ``reward`` and ``done`` are broadcast over dim 1.
    Arguments:
        - data (:obj:`v_1step_td_data`): The input data, v_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, expected to return per-sample loss
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean 1step td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Unweighted td error, same shape as ``v``
    Shapes:
        - data (:obj:`v_1step_td_data`): the v_1step_td_data containing \
            ['v', 'next_v', 'reward', 'done', 'weight']
        - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
        - next_v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor` or None): :math:`(B, )`, whether done \
            in last timestep; ``None`` means no transition is treated as terminal
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> v = torch.randn(5).requires_grad_(True)
        >>> next_v = torch.randn(5)
        >>> reward = torch.rand(5)
        >>> done = torch.zeros(5)
        >>> data = v_1step_td_data(v, next_v, reward, done, None)
        >>> loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    '''
    v, next_v, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(v)
    # ``1 - done`` is not defined for bool tensors in PyTorch, so cast the documented BoolTensor first.
    if torch.is_tensor(done) and done.dtype == torch.bool:
        done = done.float()
    if len(v.shape) != len(reward.shape):
        # v carries an extra dim; make the per-sample reward/done broadcast over it.
        reward = reward.unsqueeze(1)
        if done is not None:
            done = done.unsqueeze(1)
    if done is not None:
        target_v = gamma * (1 - done) * next_v + reward
    else:
        target_v = gamma * next_v + reward
    td_error_per_sample = criterion(v, target_v.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
# Input record for ``v_nstep_td_error``; that function unpacks the fields positionally in this order.
v_nstep_td_data = namedtuple(
    'v_nstep_td_data',
    ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma'],
)
def v_nstep_td_error(
        data: namedtuple,
        gamma: float,
        nstep: int = 1,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')
) -> torch.Tensor:
    r"""
    Overview:
        Multistep (n step) td_error for state-value based algorithm. The target is produced by \
        ``nstep_return`` and detached before being compared with ``v``.
    Arguments:
        - data (:obj:`v_nstep_td_data`): The input data, v_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Unweighted output of ``criterion``
    Shapes:
        - data (:obj:`v_nstep_td_data`): The v_nstep_td_data containing
            ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']
        - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
        - next_n_v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
        - value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step
            we use value_gamma as the gamma discount value for next_v rather than gamma**n_step
    Examples:
        >>> v = torch.randn(5).requires_grad_(True)
        >>> next_v = torch.randn(5)
        >>> reward = torch.rand(5, 5)
        >>> done = torch.zeros(5)
        >>> data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
        >>> loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    """
    v, next_n_v, reward, done, weight, value_gamma = data
    # Missing weight means every sample counts equally.
    sample_weight = torch.ones_like(v) if weight is None else weight
    return_input = nstep_return_data(reward, next_n_v, done)
    target_v = nstep_return(return_input, gamma, nstep, value_gamma).detach()
    td_error_per_sample = criterion(v, target_v)
    loss = (td_error_per_sample * sample_weight).mean()
    return loss, td_error_per_sample
# Input record shared by the n-step q-learning td error functions in this file, which unpack the
# fields positionally in this order.
q_nstep_td_data = namedtuple(
    'q_nstep_td_data',
    ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'],
)

# Record with additional one-step and expert-flag fields (named after DQfD; its consumer is not
# visible in this part of the file).
dqfd_nstep_td_data = namedtuple(
    'dqfd_nstep_td_data',
    [
        'q',
        'next_n_q',
        'action',
        'next_n_action',
        'reward',
        'done',
        'done_one_step',
        'weight',
        'new_n_q_one_step',
        'next_n_action_one_step',
        'is_expert',
    ],
)
def shape_fn_qntd(args, kwargs):
    r"""
    Overview:
        Return qntd shape for hpc. The data record is taken from the first positional argument when \
        there is one, otherwise from ``kwargs['data']``.
    Arguments:
        - args: Positional arguments of the wrapped call
        - kwargs: Keyword arguments of the wrapped call
    Returns:
        shape: ``[reward.shape[0]] + list(q.shape)``, documented upstream as [T, B, N]
    """
    td_data = args[0] if len(args) > 0 else kwargs['data']
    return [td_data.reward.shape[0], *td_data.q.shape]
@hpc_wrapper(shape_fn=shape_fn_qntd, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma'])
def q_nstep_td_error(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> tuple:
    """
    Overview:
        Multistep (1 step or n step) td_error for q-learning based algorithm
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randint(0, 2, (4, )).bool()
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep =3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(reward)
    # ``1 - done`` is not defined for bool tensors in PyTorch, so cast the documented BoolTensor first.
    if torch.is_tensor(done) and done.dtype == torch.bool:
        done = done.float()

    if len(action.shape) == 1:
        # Single agent case: add the index dim expected by ``gather``.
        action = action.unsqueeze(-1)
    elif len(action.shape) > 1:
        # Multi agent case: per-sample quantities get a trailing dim so they broadcast over agents.
        reward = reward.unsqueeze(-1)
        weight = weight.unsqueeze(-1)
        done = done.unsqueeze(-1)
        if value_gamma is not None:
            value_gamma = value_gamma.unsqueeze(-1)

    q_s_a = q.gather(-1, action).squeeze(-1)
    target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)

    if cum_reward:
        # ``reward`` is already the discounted n-step sum; only the bootstrap term is added here.
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def bdq_nstep_td_error(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> tuple:
    """
    Overview:
        Multistep (1 step or n step) td_error for BDQ algorithm, referenced paper "Action Branching Architectures for \
        Deep Reinforcement Learning", link: https://arxiv.org/pdf/1711.08946.
        In fact, the original paper only provides the 1-step TD-error calculation method, and here we extend the \
        calculation method of n-step. The per-branch td errors are averaged over the branch dim.
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, expected to return per-element loss
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` i.e. [batch_size, branch_num, action_bins_per_branch]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, D, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, D)`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, D)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`): :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> action_per_branch = 3
        >>> next_q = torch.randn(8, 6, action_per_branch)
        >>> done = torch.randint(0, 2, (8, )).bool()
        >>> action = torch.randint(0, action_per_branch, size=(8, 6))
        >>> next_action = torch.randint(0, action_per_branch, size=(8, 6))
        >>> nstep =3
        >>> q = torch.randn(8, 6, action_per_branch).requires_grad_(True)
        >>> reward = torch.rand(nstep, 8)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(reward)
    # ``1 - done`` is not defined for bool tensors in PyTorch, so cast the documented BoolTensor first.
    if torch.is_tensor(done) and done.dtype == torch.bool:
        done = done.float()
    # Per-sample quantities get a trailing dim so they broadcast over the branch dim.
    reward = reward.unsqueeze(-1)
    done = done.unsqueeze(-1)
    if value_gamma is not None:
        value_gamma = value_gamma.unsqueeze(-1)

    q_s_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)

    if cum_reward:
        # ``reward`` is already the discounted n-step sum; only the bootstrap term is added here.
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    # Average the td error over the action branches.
    td_error_per_sample = td_error_per_sample.mean(-1)
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def shape_fn_qntd_rescale(args, kwargs):
    r"""
    Overview:
        Return qntd_rescale shape for hpc. The data record is taken from the first positional argument \
        when there is one, otherwise from ``kwargs['data']``.
    Arguments:
        - args: Positional arguments of the wrapped call
        - kwargs: Keyword arguments of the wrapped call
    Returns:
        shape: ``[reward.shape[0]] + list(q.shape)``, documented upstream as [T, B, N]
    """
    td_data = args[0] if len(args) > 0 else kwargs['data']
    return [td_data.reward.shape[0], *td_data.q.shape]
@hpc_wrapper(
    shape_fn=shape_fn_qntd_rescale, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma']
)
def q_nstep_td_error_with_rescale(
    data: namedtuple,
    gamma: Union[float, list],
    nstep: int = 1,
    value_gamma: Optional[torch.Tensor] = None,
    criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
    trans_fn: Callable = value_transform,
    inv_trans_fn: Callable = value_inv_transform,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error computed with value rescaling: the bootstrap value is mapped \
        through ``inv_trans_fn``, the n-step return is built, and the result is mapped back with ``trans_fn``.
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor`): Optional discount for the bootstrap value, forwarded to \
            ``nstep_return``
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - trans_fn (:obj:`Callable`): Value transform function, default to value_transform \
            (refer to rl_utils/value_rescale.py)
        - inv_trans_fn (:obj:`Callable`): Value inverse transform function, default to value_inv_transform \
            (refer to rl_utils/value_rescale.py)
    Returns:
        - loss (:obj:`torch.Tensor`): weighted mean of the nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted output of ``criterion`` for each sample
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    assert len(action.shape) == 1, action.shape
    sample_weight = torch.ones_like(action) if weight is None else weight

    # Pick the Q-value of the taken / bootstrap action for every sample in the batch.
    rows = torch.arange(action.shape[0])
    current_q = q[rows, action]
    bootstrap_q = inv_trans_fn(next_n_q[rows, next_n_action])

    # The return is accumulated in the un-rescaled space and then mapped back.
    nstep_target = nstep_return(nstep_return_data(reward, bootstrap_q, done), gamma, nstep, value_gamma)
    rescaled_target = trans_fn(nstep_target).detach()

    td_error_per_sample = criterion(current_q, rescaled_target)
    return (td_error_per_sample * sample_weight).mean(), td_error_per_sample
def dqfd_nstep_td_error(
    data: namedtuple,
    gamma: float,
    lambda_n_step_td: float,
    lambda_supervised_loss: float,
    margin_function: float,
    lambda_one_step_td: float = 1.,
    nstep: int = 1,
    cum_reward: bool = False,
    value_gamma: Optional[torch.Tensor] = None,
    criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep n step td_error + 1 step td_error + supervised margin loss for DQfD
    Arguments:
        - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - lambda_n_step_td (:obj:`float`): Coefficient of the n-step td error
        - lambda_supervised_loss (:obj:`float`): Coefficient of the supervised margin loss
        - margin_function (:obj:`float`): Margin added to the Q-value of every action other than ``action``
        - lambda_one_step_td (:obj:`float`): Coefficient of the 1-step td error, default set to 1
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value (n-step term only)
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean of n step td_error + 1 step td_error + supervised margin \
            loss, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Sum of the absolute values of the three weighted terms, \
            1-dim tensor
        - loss_statistics (:obj:`tuple`): Means of the n-step td error, the 1-step td error and the margin loss
    Shapes:
        - data (:obj:`dqfd_nstep_td_data`): the data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_one_step', 'weight', \
            'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
        - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
        - is_expert (:obj:`int`) : 0 or 1
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.zeros(4)
        >>> done_1 = torch.zeros(4)
        >>> next_q_one_step = torch.randn(4, 3)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> next_action_one_step = torch.randint(0, 3, size=(4, ))
        >>> is_expert = torch.ones((4))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = dqfd_nstep_td_data(
        >>>     q, next_q, action, next_action, reward, done, done_1, None,
        >>>     next_q_one_step, next_action_one_step, is_expert
        >>> )
        >>> loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
        >>>     data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1,
        >>>     margin_function=0.8, nstep=nstep
        >>> )
    """
    q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, \
        next_n_action_one_step, is_expert = data
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    target_q_s_a = next_n_q[batch_range, next_n_action]
    target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]

    # ---- n-step td error ----
    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())

    # ---- 1-step td error ----
    # The one-step target always uses nstep=1 and never uses value_gamma, so only the first reward entry is kept.
    # NOTE(review): with cum_reward=True the reward is presumably already summed over steps, in which case
    # ``reward[0]`` selects the first batch element rather than the first step - confirm against the caller.
    reward_one_step = reward[0].unsqueeze(0)
    if cum_reward:
        target_q_s_a_one_step = reward_one_step + gamma * target_q_s_a_one_step * (1 - done_one_step)
    else:
        target_q_s_a_one_step = nstep_return(
            nstep_return_data(reward_one_step, target_q_s_a_one_step, done_one_step), gamma, 1, None
        )
    td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())

    # ---- supervised margin loss ----
    # margin is ``margin_function`` for every action except the one actually taken (where it is 0).
    # It is built directly on q's device, which avoids a host round-trip.
    margin = margin_function * torch.ones_like(q)
    margin.scatter_(1, action.unsqueeze(1).to(device=q.device, dtype=torch.long), 0.0)
    JE = is_expert * (torch.max(q + margin, dim=1)[0] - q_s_a)

    combined = (
        lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
        lambda_supervised_loss * JE
    )
    combined_abs = (
        lambda_n_step_td * td_error_per_sample.abs() + lambda_one_step_td * td_error_one_step_per_sample.abs() +
        lambda_supervised_loss * JE.abs()
    )
    return (
        (combined * weight).mean(), combined_abs,
        (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
    )
def dqfd_nstep_td_error_with_rescale(
    data: namedtuple,
    gamma: float,
    lambda_n_step_td: float,
    lambda_supervised_loss: float,
    lambda_one_step_td: float,
    margin_function: float,
    nstep: int = 1,
    cum_reward: bool = False,
    value_gamma: Optional[torch.Tensor] = None,
    criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
    trans_fn: Callable = value_transform,
    inv_trans_fn: Callable = value_inv_transform,
) -> torch.Tensor:
    """
    Overview:
        Multistep n step td_error + 1 step td_error + supervised margin loss for DQfD, with value rescaling \
        applied to both td targets.
    Arguments:
        - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - lambda_n_step_td (:obj:`float`): Coefficient of the n-step td error
        - lambda_supervised_loss (:obj:`float`): Coefficient of the supervised margin loss
        - lambda_one_step_td (:obj:`float`): Coefficient of the 1-step td error
        - margin_function (:obj:`float`): Margin added to the Q-value of every action other than ``action``
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value (n-step term only)
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - trans_fn (:obj:`Callable`): Value transform function applied to each td target
        - inv_trans_fn (:obj:`Callable`): Inverse value transform applied to each bootstrap Q-value
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted mean of n step td_error + 1 step td_error + supervised margin \
            loss, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Sum of the absolute values of the three weighted terms, \
            1-dim tensor
        - loss_statistics (:obj:`tuple`): Means of the n-step td error, the 1-step td error and the margin loss
    Shapes:
        - data (:obj:`dqfd_nstep_td_data`): The data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_one_step', 'weight', \
            'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
        - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
        - is_expert (:obj:`int`) : 0 or 1
    """
    q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, \
        next_n_action_one_step, is_expert = data
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    # Bootstrap values are mapped back to the un-rescaled space before the returns are accumulated.
    target_q_s_a = inv_trans_fn(next_n_q[batch_range, next_n_action])
    target_q_s_a_one_step = inv_trans_fn(new_n_q_one_step[batch_range, next_n_action_one_step])

    # ---- n-step td error ----
    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    target_q_s_a = trans_fn(target_q_s_a)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())

    # ---- 1-step td error ----
    # The one-step target always uses nstep=1 and never uses value_gamma, so only the first reward entry is kept.
    # NOTE(review): with cum_reward=True the reward is presumably already summed over steps, in which case
    # ``reward[0]`` selects the first batch element rather than the first step - confirm against the caller.
    reward_one_step = reward[0].unsqueeze(0)
    if cum_reward:
        target_q_s_a_one_step = reward_one_step + gamma * target_q_s_a_one_step * (1 - done_one_step)
    else:
        target_q_s_a_one_step = nstep_return(
            nstep_return_data(reward_one_step, target_q_s_a_one_step, done_one_step), gamma, 1, None
        )
    target_q_s_a_one_step = trans_fn(target_q_s_a_one_step)
    td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())

    # ---- supervised margin loss ----
    # margin is ``margin_function`` for every action except the one actually taken (where it is 0).
    # It is built directly on q's device, which avoids a host round-trip.
    margin = margin_function * torch.ones_like(q)
    margin.scatter_(1, action.unsqueeze(1).to(device=q.device, dtype=torch.long), 0.0)
    JE = is_expert * (torch.max(q + margin, dim=1)[0] - q_s_a)

    combined = (
        lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
        lambda_supervised_loss * JE
    )
    combined_abs = (
        lambda_n_step_td * td_error_per_sample.abs() + lambda_one_step_td * td_error_one_step_per_sample.abs() +
        lambda_supervised_loss * JE.abs()
    )
    return (
        (combined * weight).mean(), combined_abs,
        (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
    )
# Input record unpacked by ``qrdqn_nstep_td_error``; field order matters because the function
# destructures the tuple positionally.
qrdqn_nstep_td_data = namedtuple(
    'qrdqn_nstep_td_data',
    (
        'q',
        'next_n_q',
        'action',
        'next_n_action',
        'reward',
        'done',
        'tau',
        'weight',
    ),
)
def qrdqn_nstep_td_error(
    data: namedtuple,
    gamma: float,
    nstep: int = 1,
    value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) quantile-regression td_error used in QRDQN
    Arguments:
        - data (:obj:`qrdqn_nstep_td_data`): The input data, qrdqn_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor`): Optional per-sample discount for the bootstrap value, used \
            instead of ``gamma ** nstep`` when given
    Returns:
        - loss (:obj:`torch.Tensor`): weighted mean of the nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted quantile loss of each sample, :math:`(B, )`
    Shapes:
        - data (:obj:`qrdqn_nstep_td_data`): The qrdqn_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'tau', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N, tau)` i.e. [batch_size, action_dim, num_quantiles]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N, tau')`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`) :math:`(B, )`, whether done in last timestep
        - tau: quantile fractions; must broadcast against the :math:`(B, tau, tau')` pairwise loss \
            (presumably shaped :math:`(1, tau, 1)` - confirm against the caller)
    Examples:
        >>> next_q = torch.randn(4, 3, 3)
        >>> done = torch.zeros(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, 3, None)
        >>> loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, tau, weight = data

    assert len(action.shape) == 1, action.shape
    assert len(next_n_action.shape) == 1, next_n_action.shape
    assert len(done.shape) == 1, done.shape
    assert len(q.shape) == 3, q.shape
    assert len(next_n_q.shape) == 3, next_n_q.shape
    assert len(reward.shape) == 2, reward.shape

    if weight is None:
        weight = torch.ones_like(action)
    if done.dtype == torch.bool:
        # ``1 - done`` is not defined for bool tensors in PyTorch, so use a numeric mask instead.
        done = done.to(reward.dtype)

    batch_range = torch.arange(action.shape[0])
    # (B, tau, 1) current quantiles vs. (B, 1, tau') target quantiles -> pairwise (B, tau, tau') errors.
    q_s_a = q[batch_range, action, :].unsqueeze(2)
    target_q_s_a = next_n_q[batch_range, next_n_action, :].unsqueeze(1)

    assert reward.shape[0] == nstep
    # reward_factor[i] = gamma ** i; the matmul collapses the (T, B) rewards into a discounted (B, ) sum.
    reward_factor = torch.ones(nstep).to(reward)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)

    if value_gamma is None:
        bootstrap_discount = gamma ** nstep
    else:
        bootstrap_discount = value_gamma.unsqueeze(-1).unsqueeze(-1)
    not_done = (1 - done).unsqueeze(-1).unsqueeze(-1)
    target_q_s_a = reward.unsqueeze(-1).unsqueeze(-1) + bootstrap_discount * target_q_s_a * not_done

    # Quantile Huber loss: Huber term weighted by |tau - 1{target - q <= 0}|.
    u = F.smooth_l1_loss(target_q_s_a, q_s_a, reduction="none")
    loss = (u * (tau - (target_q_s_a - q_s_a).detach().le(0.).float()).abs()).sum(-1).mean(1)

    return (loss * weight).mean(), loss
def q_nstep_sql_td_error(
    data: namedtuple,
    gamma: float,
    alpha: float,
    nstep: int = 1,
    cum_reward: bool = False,
    value_gamma: Optional[torch.Tensor] = None,
    criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for soft q-learning: the bootstrap value is the soft value \
        ``alpha * logsumexp(next_n_q / alpha)`` instead of the Q-value of a single action.
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_sql_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - alpha (:obj:`float`): A parameter to weight entropy term in a policy equation
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target soft_q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
        - record_target_v (:obj:`torch.Tensor`): Detached copy of the soft value before reward/discount is applied
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`, unpacked but not used by this loss
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.zeros(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]

    # Soft state value of the bootstrap state; infinities are clamped to +-20.
    target_v = alpha * torch.logsumexp(next_n_q / alpha, 1)
    target_v[target_v == float("Inf")] = 20
    target_v[target_v == float("-Inf")] = -20

    # ``detach().clone()`` yields the same independent, gradient-free snapshot as a deep copy, but also works
    # when ``target_v`` is a non-leaf autograd tensor (which ``copy.deepcopy`` rejects).
    record_target_v = target_v.detach().clone()

    if cum_reward:
        if value_gamma is None:
            target_v = reward + (gamma ** nstep) * target_v * (1 - done)
        else:
            target_v = reward + value_gamma * target_v * (1 - done)
    else:
        target_v = nstep_return(nstep_return_data(reward, target_v, done), gamma, nstep, value_gamma)

    td_error_per_sample = criterion(q_s_a, target_v.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample, record_target_v
# Input record unpacked by ``iqn_nstep_td_error``; field order matters because the function
# destructures the tuple positionally.
iqn_nstep_td_data = namedtuple(
    'iqn_nstep_td_data',
    (
        'q',
        'next_n_q',
        'action',
        'next_n_action',
        'reward',
        'done',
        'replay_quantiles',
        'weight',
    ),
)
def iqn_nstep_td_error(
    data: namedtuple,
    gamma: float,
    nstep: int = 1,
    kappa: float = 1.0,
    value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error with in IQN, \
        referenced paper Implicit Quantile Networks for Distributional Reinforcement Learning \
        <https://arxiv.org/pdf/1806.06923.pdf>
    Arguments:
        - data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - kappa (:obj:`float`): Threshold of the Huber loss; the loss is also divided by it
        - value_gamma (:obj:`torch.Tensor`): Optional per-sample discount for the bootstrap value, used \
            instead of ``gamma ** nstep`` when given
    Returns:
        - loss (:obj:`torch.Tensor`): weighted mean of the nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted quantile Huber loss of each sample, :math:`(B, )`
    Shapes:
        - data (:obj:`iqn_nstep_td_data`): The iqn_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'replay_quantiles', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. [tau, batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`) :math:`(B, )`, whether done in last timestep
        - replay_quantiles (:obj:`torch.FloatTensor`): anything reshapeable to :math:`(tau, B, 1)`
    Examples:
        >>> next_q = torch.randn(3, 4, 3)
        >>> done = torch.zeros(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(3, 4, 3).requires_grad_(True)
        >>> replay_quantile = torch.randn([3, 4, 1])
        >>> reward = torch.rand(nstep, 4)
        >>> data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
        >>> loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, replay_quantiles, weight = data

    assert len(action.shape) == 1, action.shape
    assert len(next_n_action.shape) == 1, next_n_action.shape
    assert len(done.shape) == 1, done.shape
    assert len(q.shape) == 3, q.shape
    assert len(next_n_q.shape) == 3, next_n_q.shape
    assert len(reward.shape) == 2, reward.shape

    if weight is None:
        weight = torch.ones_like(action)
    if done.dtype == torch.bool:
        # ``1 - done`` is not defined for bool tensors in PyTorch, so use a numeric mask instead.
        done = done.to(reward.dtype)

    batch_size = done.shape[0]
    tau = q.shape[0]
    tau_prime = next_n_q.shape[0]

    # Select the quantile values of the chosen actions and move batch to the front: (B, tau, 1) / (B, tau', 1).
    action = action.repeat([tau, 1]).unsqueeze(-1)
    next_n_action = next_n_action.repeat([tau_prime, 1]).unsqueeze(-1)
    q_s_a = torch.gather(q, -1, action).permute([1, 0, 2])
    target_q_s_a = torch.gather(next_n_q, -1, next_n_action).permute([1, 0, 2])

    assert reward.shape[0] == nstep
    # reward_factor[i] = gamma ** i; it is created on reward's own device so the matmul below also works
    # when reward lives on a non-default CUDA device.
    reward_factor = torch.ones(nstep).to(reward.device)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)

    if value_gamma is None:
        bootstrap_discount = gamma ** nstep
    else:
        bootstrap_discount = value_gamma.unsqueeze(-1)
    target_q_s_a = reward.unsqueeze(-1) + bootstrap_discount * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1)
    target_q_s_a = target_q_s_a.unsqueeze(-1)

    # Pairwise errors between every target quantile and every current quantile: (B, tau', tau, 1).
    bellman_errors = (target_q_s_a[:, :, None, :] - q_s_a[:, None, :, :])

    # The huber loss is applied to all pairwise errors: (B, tau', tau, 1).
    huber_loss = torch.where(
        bellman_errors.abs() <= kappa, 0.5 * bellman_errors ** 2, kappa * (bellman_errors.abs() - 0.5 * kappa)
    )

    # Broadcast the sampled quantile fractions over the tau' axis to match the pairwise errors.
    replay_quantiles = replay_quantiles.reshape([tau, batch_size, 1]).permute([1, 0, 2])
    replay_quantiles = replay_quantiles[:, None, :, :].repeat([1, tau_prime, 1, 1])

    quantile_huber_loss = (torch.abs(replay_quantiles - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa

    # Sum over current quantiles, average over target quantiles.
    loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0]

    return (loss * weight).mean(), loss
# Input record unpacked by ``fqf_nstep_td_error``; field order matters because the function
# destructures the tuple positionally.
fqf_nstep_td_data = namedtuple(
    'fqf_nstep_td_data',
    (
        'q',
        'next_n_q',
        'action',
        'next_n_action',
        'reward',
        'done',
        'quantiles_hats',
        'weight',
    ),
)
def fqf_nstep_td_error(
    data: namedtuple,
    gamma: float,
    nstep: int = 1,
    kappa: float = 1.0,
    value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error with in FQF, \
        referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
        <https://arxiv.org/pdf/1911.02140.pdf>
    Arguments:
        - data (:obj:`fqf_nstep_td_data`): The input data, fqf_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - kappa (:obj:`float`): Divisor of the quantile Huber loss (see the note in the body)
        - value_gamma (:obj:`torch.Tensor`): Optional per-sample discount for the bootstrap value, used \
            instead of ``gamma ** nstep`` when given
    Returns:
        - loss (:obj:`torch.Tensor`): weighted mean of the nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted quantile Huber loss of each sample, :math:`(B, )`
    Shapes:
        - data (:obj:`fqf_nstep_td_data`): The fqf_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'quantiles_hats', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, tau, N)` i.e. [batch_size, tau, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, tau', N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`) :math:`(B, )`, whether done in last timestep
        - quantiles_hats (:obj:`torch.FloatTensor`): :math:`(B, tau)`
    Examples:
        >>> next_q = torch.randn(4, 3, 3)
        >>> done = torch.zeros(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3, 3).requires_grad_(True)
        >>> quantiles_hats = torch.randn([4, 3])
        >>> reward = torch.rand(nstep, 4)
        >>> data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None)
        >>> loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, quantiles_hats, weight = data

    assert len(action.shape) == 1, action.shape
    assert len(next_n_action.shape) == 1, next_n_action.shape
    assert len(done.shape) == 1, done.shape
    assert len(q.shape) == 3, q.shape
    assert len(next_n_q.shape) == 3, next_n_q.shape
    assert len(reward.shape) == 2, reward.shape

    if weight is None:
        weight = torch.ones_like(action)
    if done.dtype == torch.bool:
        # ``1 - done`` is not defined for bool tensors in PyTorch, so use a numeric mask instead.
        done = done.to(reward.dtype)

    tau_prime = next_n_q.shape[1]

    # Quantile values of the chosen actions: (B, tau, 1) and (B, tau', 1).
    q_s_a = evaluate_quantile_at_action(q, action)
    target_q_s_a = evaluate_quantile_at_action(next_n_q, next_n_action)

    assert reward.shape[0] == nstep
    # reward_factor[i] = gamma ** i; the matmul collapses the (T, B) rewards into a discounted (B, ) sum.
    reward_factor = torch.ones(nstep).to(reward.device)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)

    if value_gamma is None:
        bootstrap_discount = gamma ** nstep
    else:
        bootstrap_discount = value_gamma.unsqueeze(-1)
    target_q_s_a = reward.unsqueeze(-1) + bootstrap_discount * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1)
    target_q_s_a = target_q_s_a.unsqueeze(-1)

    # Pairwise errors between every target quantile and every current quantile: (B, tau', tau, 1).
    bellman_errors = (target_q_s_a.unsqueeze(2) - q_s_a.unsqueeze(1))

    # NOTE(review): smooth_l1_loss is called with its default threshold, so ``kappa`` only rescales the
    # loss below and does not move the quadratic/linear switch point - confirm this is intended.
    huber_loss = F.smooth_l1_loss(target_q_s_a.unsqueeze(2), q_s_a.unsqueeze(1), reduction="none")

    # Broadcast the quantile fractions over the tau' axis to match the pairwise errors.
    quantiles_hats = quantiles_hats[:, None, :, None].repeat([1, tau_prime, 1, 1])

    quantile_huber_loss = (torch.abs(quantiles_hats - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa

    # Sum over current quantiles, average over target quantiles.
    loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0]

    return (loss * weight).mean(), loss
def evaluate_quantile_at_action(q_s, actions):
    """
    Overview:
        Pick, for every sample and every quantile, the value that belongs to the given action.
    Arguments:
        - q_s (:obj:`torch.Tensor`): Quantile values, indexed as [batch, quantile, action]
        - actions (:obj:`torch.LongTensor`): One action index per sample, :math:`(B, )`
    Returns:
        - q_s_a (:obj:`torch.Tensor`): Gathered values with a trailing singleton axis, \
            i.e. [batch, quantile, 1]
    """
    assert q_s.shape[0] == actions.shape[0]

    n_batch, n_quantiles = q_s.shape[:2]

    # The same action index is repeated along the quantile axis so it can drive a gather on the last axis.
    gather_index = actions.unsqueeze(-1).unsqueeze(-1).expand(n_batch, n_quantiles, 1)
    return torch.gather(q_s, 2, gather_index)
def fqf_calculate_fraction_loss(q_tau_i, q_value, quantiles, actions):
    """
    Overview:
        Calculate the fraction loss in FQF, \
        referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
        <https://arxiv.org/pdf/1911.02140.pdf>
    Arguments:
        - q_tau_i (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles-1, action_dim)`
        - q_value (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles, action_dim)`
        - quantiles (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles+1)`
        - actions (:obj:`torch.LongTensor`): :math:`(batch_size, )`
    Returns:
        - fraction_loss (:obj:`torch.Tensor`): fraction loss, 0-dim tensor
    """
    assert q_value.requires_grad

    n_batch, n_quantiles = q_value.shape[0], q_value.shape[1]

    # Quantile values of the chosen action are treated as constants here.
    with torch.no_grad():
        sa_quantiles = evaluate_quantile_at_action(q_tau_i, actions)
        assert sa_quantiles.shape == (n_batch, n_quantiles - 1, 1)
        q_s_a_hats = evaluate_quantile_at_action(q_value, actions)
        assert q_s_a_hats.shape == (n_batch, n_quantiles, 1)
        assert not q_s_a_hats.requires_grad

    # Gap to the lower neighbouring estimate; its sign is flipped where the quantile is not above its
    # lower reference.
    lower_gap = sa_quantiles - q_s_a_hats[:, :-1]
    lower_reference = torch.cat([q_s_a_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
    lower_ordered = sa_quantiles > lower_reference
    assert lower_gap.shape == lower_ordered.shape

    # Gap to the upper neighbouring estimate; its sign is flipped where the quantile is not below its
    # upper reference.
    upper_gap = sa_quantiles - q_s_a_hats[:, 1:]
    upper_reference = torch.cat([sa_quantiles[:, 1:], q_s_a_hats[:, -1:]], dim=1)
    upper_ordered = sa_quantiles < upper_reference
    assert upper_gap.shape == upper_ordered.shape

    lower_term = torch.where(lower_ordered, lower_gap, -lower_gap)
    upper_term = torch.where(upper_ordered, upper_gap, -upper_gap)
    gradient_of_taus = (lower_term + upper_term).view(n_batch, n_quantiles - 1)
    assert not gradient_of_taus.requires_grad

    inner_quantiles = quantiles[:, 1:-1]
    assert gradient_of_taus.shape == inner_quantiles.shape

    # Multiplying the constant term by the quantile fractions makes it their gradient under autograd.
    return (gradient_of_taus * inner_quantiles).sum(dim=1).mean()
# Input record unpacked by ``td_lambda_error``; field order matters because the function
# destructures the tuple positionally.
td_lambda_data = namedtuple(
    'td_lambda_data',
    ('value', 'reward', 'weight'),
)
def shape_fn_td_lambda(args, kwargs):
    r"""
    Overview:
        Return td_lambda shape for hpc
    Arguments:
        - args (:obj:`tuple`): Positional arguments of the wrapped call; ``args[0]`` is the td data if present.
        - kwargs (:obj:`dict`): Keyword arguments of the wrapped call; ``kwargs['data']`` is read when \
            ``args`` is empty.
    Returns:
        - shape (:obj:`torch.Size`): shape of ``data.reward``, i.e. [T, B]
    """
    # Both call styles must yield the full reward shape; the keyword path used to return only ``shape[0]``.
    td_data = args[0] if len(args) > 0 else kwargs['data']
    return td_data.reward.shape
@hpc_wrapper(
    shape_fn=shape_fn_td_lambda,
    namedtuple_data=True,
    include_args=[0, 1, 2],
    include_kwargs=['data', 'gamma', 'lambda_']
)
def td_lambda_error(data: namedtuple, gamma: float = 0.9, lambda_: float = 0.8) -> torch.Tensor:
    """
    Overview:
        Compute the TD(lambda) loss for a constant gamma and lambda.
        Terminal states get no special treatment: if a trajectory has terminated, the caller is expected to
        fill values and rewards beyond the terminal with zeros
        (*including the terminal state*, values[terminal] should also be 0)
    Arguments:
        - data (:obj:`namedtuple`): td_lambda input data with fields ['value', 'reward', 'weight']
        - gamma (:obj:`float`): Constant discount factor gamma, should be in [0, 1], defaults to 0.9
        - lambda_ (:obj:`float`): Constant lambda, should be in [0, 1], defaults to 0.8
    Returns:
        - loss (:obj:`torch.Tensor`): Half of the weighted mean squared error between the lambda return \
            and ``value[:-1]``
    Shapes:
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`, where T is trajectory length and B is batch, \
            which is the estimation of the state value at step 0 to T
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, the returns from time step 0 to T-1
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
        - loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    Examples:
        >>> T, B = 8, 4
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> loss = td_lambda_error(td_lambda_data(value, reward, None))
    """
    value, reward, weight = data
    sample_weight = torch.ones_like(reward) if weight is None else weight

    # The lambda return is a regression target, so it is built without tracking gradients.
    with torch.no_grad():
        lambda_return = generalized_lambda_returns(value, reward, gamma, lambda_)

    squared_error = F.mse_loss(lambda_return, value[:-1], reduction='none')
    return 0.5 * (squared_error * sample_weight).mean()
def generalized_lambda_returns(
        bootstrap_values: torch.Tensor,
        rewards: torch.Tensor,
        gammas: float,
        lambda_: float,
        done: Optional[torch.Tensor] = None
) -> torch.Tensor:
    r"""
    Overview:
        Functional equivalent to trfl.value_ops.generalized_lambda_returns
        https://github.com/deepmind/trfl/blob/2c07ac22512a16715cc759f0072be43a5d12ae45/trfl/value_ops.py#L74
        A plain number may be passed for ``gammas`` / ``lambda_``; it is then used for every element of \
        ``rewards``.
    Arguments:
        - bootstrap_values (:obj:`torch.Tensor` or :obj:`float`):
            estimation of the value at step 0 to *T*, of size [T_traj+1, batchsize]
        - rewards (:obj:`torch.Tensor`): The returns from 0 to T-1, of size [T_traj, batchsize]
        - gammas (:obj:`torch.Tensor` or :obj:`float`):
            Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
        - lambda_ (:obj:`torch.Tensor` or :obj:`float`): Determining the mix of bootstrapping
            vs further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]
        - done (:obj:`torch.Tensor` or :obj:`float`):
            Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize]
    Returns:
        - return (:obj:`torch.Tensor`): Computed lambda return value
            for each state from 0 to T-1, of size [T_traj, batchsize]
    """

    def _like_rewards(coefficient):
        # Tensors pass through untouched; anything else is expanded to the shape of ``rewards``.
        if isinstance(coefficient, torch.Tensor):
            return coefficient
        return coefficient * torch.ones_like(rewards)

    gammas = _like_rewards(gammas)
    lambda_ = _like_rewards(lambda_)

    # Drop the value of step 0: the forward view only bootstraps from steps 1..T.
    values_from_step_one = bootstrap_values[1:, :]
    return multistep_forward_view(values_from_step_one, rewards, gammas, lambda_, done)
def multistep_forward_view(
        bootstrap_values: torch.Tensor,
        rewards: torch.Tensor,
        gammas: float,
        lambda_: float,
        done: Optional[torch.Tensor] = None
) -> torch.Tensor:
    r"""
    Overview:
        Same as trfl.sequence_ops.multistep_forward_view
        Implementing (12.18) in Sutton & Barto

        ```
        result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T]
        for t in 0...T-2 :
            result[t] = rewards[t] + gammas[t]*(lambdas[t]*result[t+1] + (1-lambdas[t])*bootstrap_values[t+1])
        ```

        Every term after the reward is additionally multiplied by ``1 - done[t]``.
        Assuming the first dim of input tensors correspond to the index in batch
    Arguments:
        - bootstrap_values (:obj:`torch.Tensor`): Estimation of the value at *step 1 to T*, of size [T_traj, batchsize]
        - rewards (:obj:`torch.Tensor`): The returns from 0 to T-1, of size [T_traj, batchsize]
        - gammas (:obj:`torch.Tensor`): Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
        - lambda_ (:obj:`torch.Tensor`): Determining the mix of bootstrapping vs further accumulation of \
            multistep returns at each timestep of size [T_traj, batchsize], the element for T-1 is ignored \
            and effectively set to 0, as there is no information about future rewards.
        - done (:obj:`torch.Tensor` or :obj:`float`):
            Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize]; \
            treated as all zeros when omitted
    Returns:
        - ret (:obj:`torch.Tensor`): Computed lambda return value \
            for each state from 0 to T-1, of size [T_traj, batchsize]
    """
    if done is None:
        done = torch.zeros_like(rewards)
    not_done = 1 - done

    # Weight of the recursive return and of the bootstrap value at every step.
    lambda_discounts = gammas * lambda_
    bootstrap_weights = gammas - lambda_discounts

    returns = torch.empty_like(rewards)
    # The last step has no successor return, so it bootstraps entirely from the value estimate.
    returns[-1, :] = rewards[-1, :] + not_done[-1, :] * gammas[-1, :] * bootstrap_values[-1, :]

    # Walk backwards so that returns[t + 1] is already available when returns[t] is computed.
    for t in range(rewards.size()[0] - 2, -1, -1):
        mixed_future = lambda_discounts[t, :] * returns[t + 1, :] + bootstrap_weights[t, :] * bootstrap_values[t, :]
        returns[t, :] = rewards[t, :] + not_done[t, :] * mixed_future

    return returns