# zjowowen's picture
# init space
# 079c32c
import copy
import numpy as np
from collections import namedtuple
from typing import Union, Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
from ding.hpc_rl import hpc_wrapper
from ding.rl_utils.value_rescale import value_transform, value_inv_transform
from ding.torch_utils import to_tensor
# Input bundle for ``q_1step_td_error``; ``weight`` may be None (uniform weighting).
q_1step_td_data = namedtuple('q_1step_td_data', 'q next_q act next_act reward done weight')
def discount_cumsum(x, gamma: float = 1.0) -> np.ndarray:
    """
    Overview:
        Discounted cumulative sum along the first axis, accumulated from the last element backwards:
        ``out[t] = x[t] + gamma * out[t + 1]`` and ``out[-1] = x[-1]``.
    Arguments:
        - x (:obj:`np.ndarray`): Input sequence, indexed along axis 0.
        - gamma (:obj:`float`): Discount factor; must be (numerically) 1.0, as in the original \
            decision transformer paper.
    Returns:
        - disc_cumsum (:obj:`np.ndarray`): Array created with ``np.zeros_like(x)``, so it has the \
            shape and dtype of ``x``.
    """
    assert abs(gamma - 1.) < 1e-5, "gamma equals to 1.0 in original decision transformer paper"
    out = np.zeros_like(x)
    last = x.shape[0] - 1
    out[last] = x[last]
    for step in range(last - 1, -1, -1):
        out[step] = x[step] + gamma * out[step + 1]
    return out
def q_1step_td_error(
data: namedtuple,
gamma: float,
criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
) -> torch.Tensor:
"""
Overview:
1 step td_error, support single agent case and multi agent case.
Arguments:
- data (:obj:`q_1step_td_data`): The input data, q_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- criterion (:obj:`torch.nn.modules`): Loss function criterion
Returns:
- loss (:obj:`torch.Tensor`): 1step td error
Shapes:
- data (:obj:`q_1step_td_data`): the q_1step_td_data containing\
['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- act (:obj:`torch.LongTensor`): :math:`(B, )`
- next_act (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`( , B)`
- done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
Examples:
>>> action_dim = 4
>>> data = q_1step_td_data(
>>> q=torch.randn(3, action_dim),
>>> next_q=torch.randn(3, action_dim),
>>> act=torch.randint(0, action_dim, (3,)),
>>> next_act=torch.randint(0, action_dim, (3,)),
>>> reward=torch.randn(3),
>>> done=torch.randint(0, 2, (3,)).bool(),
>>> weight=torch.ones(3),
>>> )
>>> loss = q_1step_td_error(data, 0.99)
"""
q, next_q, act, next_act, reward, done, weight = data
assert len(act.shape) == 1, act.shape
assert len(reward.shape) == 1, reward.shape
batch_range = torch.arange(act.shape[0])
if weight is None:
weight = torch.ones_like(reward)
q_s_a = q[batch_range, act]
target_q_s_a = next_q[batch_range, next_act]
target_q_s_a = gamma * (1 - done) * target_q_s_a + reward
return (criterion(q_s_a, target_q_s_a.detach()) * weight).mean()
# Input bundle for ``m_q_1step_td_error`` (Munchausen DQN); ``weight`` may be None.
m_q_1step_td_data = namedtuple('m_q_1step_td_data', 'q target_q next_q act reward done weight')
def m_q_1step_td_error(
data: namedtuple,
gamma: float,
tau: float,
alpha: float,
criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
) -> torch.Tensor:
"""
Overview:
Munchausen td_error for DQN algorithm, support 1 step td error.
Arguments:
- data (:obj:`m_q_1step_td_data`): The input data, m_q_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- tau (:obj:`float`): Entropy factor for Munchausen DQN
- alpha (:obj:`float`): Discount factor for Munchausen term
- criterion (:obj:`torch.nn.modules`): Loss function criterion
Returns:
- loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
Shapes:
- data (:obj:`m_q_1step_td_data`): the m_q_1step_td_data containing\
['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- target_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- act (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`( , B)`
- done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
Examples:
>>> action_dim = 4
>>> data = m_q_1step_td_data(
>>> q=torch.randn(3, action_dim),
>>> target_q=torch.randn(3, action_dim),
>>> next_q=torch.randn(3, action_dim),
>>> act=torch.randint(0, action_dim, (3,)),
>>> reward=torch.randn(3),
>>> done=torch.randint(0, 2, (3,)),
>>> weight=torch.ones(3),
>>> )
>>> loss = m_q_1step_td_error(data, 0.99, 0.01, 0.01)
"""
q, target_q, next_q, act, reward, done, weight = data
lower_bound = -1
assert len(act.shape) == 1, act.shape
assert len(reward.shape) == 1, reward.shape
batch_range = torch.arange(act.shape[0])
if weight is None:
weight = torch.ones_like(reward)
q_s_a = q[batch_range, act]
# calculate muchausen addon
# replay_log_policy
target_v_s = target_q[batch_range].max(1)[0].unsqueeze(-1)
logsum = torch.logsumexp((target_q - target_v_s) / tau, 1).unsqueeze(-1)
log_pi = target_q - target_v_s - tau * logsum
act_get = act.unsqueeze(-1)
# same to the last second tau_log_pi_a
munchausen_addon = log_pi.gather(1, act_get)
muchausen_term = alpha * torch.clamp(munchausen_addon, min=lower_bound, max=1)
# replay_next_log_policy
target_v_s_next = next_q[batch_range].max(1)[0].unsqueeze(-1)
logsum_next = torch.logsumexp((next_q - target_v_s_next) / tau, 1).unsqueeze(-1)
tau_log_pi_next = next_q - target_v_s_next - tau * logsum_next
# do stable softmax == replay_next_policy
pi_target = F.softmax((next_q - target_v_s_next) / tau)
target_q_s_a = (gamma * (pi_target * (next_q - tau_log_pi_next) * (1 - done.unsqueeze(-1))).sum(1)).unsqueeze(-1)
target_q_s_a = reward.unsqueeze(-1) + muchausen_term + target_q_s_a
td_error_per_sample = criterion(q_s_a.unsqueeze(-1), target_q_s_a.detach()).squeeze(-1)
# calculate action_gap and clipfrac
with torch.no_grad():
top2_q_s = target_q[batch_range].topk(2, dim=1, largest=True, sorted=True)[0]
action_gap = (top2_q_s[:, 0] - top2_q_s[:, 1]).mean()
clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lower_bound)
clipfrac = torch.as_tensor(clipped).float()
return (td_error_per_sample * weight).mean(), td_error_per_sample, action_gap, clipfrac
# Input bundle for ``q_v_1step_td_error`` (discrete SAC); ``weight`` may be None.
q_v_1step_td_data = namedtuple('q_v_1step_td_data', 'q v act reward done weight')
def q_v_1step_td_error(
data: namedtuple, gamma: float, criterion: torch.nn.modules = nn.MSELoss(reduction='none')
) -> torch.Tensor:
# we will use this function in discrete sac algorithm to calculate td error between q and v value.
"""
Overview:
td_error between q and v value for SAC algorithm, support 1 step td error.
Arguments:
- data (:obj:`q_v_1step_td_data`): The input data, q_v_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- criterion (:obj:`torch.nn.modules`): Loss function criterion
Returns:
- loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
Shapes:
- data (:obj:`q_v_1step_td_data`): the q_v_1step_td_data containing\
['q', 'v', 'act', 'reward', 'done', 'weight']
- q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
- v (:obj:`torch.FloatTensor`): :math:`(B, )`
- act (:obj:`torch.LongTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`( , B)`
- done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
Examples:
>>> action_dim = 4
>>> data = q_v_1step_td_data(
>>> q=torch.randn(3, action_dim),
>>> v=torch.randn(3),
>>> act=torch.randint(0, action_dim, (3,)),
>>> reward=torch.randn(3),
>>> done=torch.randint(0, 2, (3,)),
>>> weight=torch.ones(3),
>>> )
>>> loss = q_v_1step_td_error(data, 0.99)
"""
q, v, act, reward, done, weight = data
if len(act.shape) == 1:
assert len(reward.shape) == 1, reward.shape
batch_range = torch.arange(act.shape[0])
if weight is None:
weight = torch.ones_like(reward)
q_s_a = q[batch_range, act]
target_q_s_a = gamma * (1 - done) * v + reward
else:
assert len(reward.shape) == 1, reward.shape
batch_range = torch.arange(act.shape[0])
actor_range = torch.arange(act.shape[1])
batch_actor_range = torch.arange(act.shape[0] * act.shape[1])
if weight is None:
weight = torch.ones_like(act)
temp_q = q.reshape(act.shape[0] * act.shape[1], -1)
temp_act = act.reshape(act.shape[0] * act.shape[1])
q_s_a = temp_q[batch_actor_range, temp_act]
q_s_a = q_s_a.reshape(act.shape[0], act.shape[1])
target_q_s_a = gamma * (1 - done).unsqueeze(1) * v + reward.unsqueeze(1)
td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
return (td_error_per_sample * weight).mean(), td_error_per_sample
def view_similar(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Return a view of ``x`` with trailing singleton dimensions appended until it has as many dimensions \
        as ``target``, so ``x`` broadcasts against ``target`` along its leading axes. If ``x`` already has \
        at least as many dimensions, its shape is left unchanged.
    Arguments:
        - x (:obj:`torch.Tensor`): Tensor to reshape.
        - target (:obj:`torch.Tensor`): Tensor whose number of dimensions is matched.
    Returns:
        - view (:obj:`torch.Tensor`): ``x.view`` with the extended shape.
    """
    missing = len(target.shape) - len(x.shape)
    return x.view(*x.shape, *([1] * missing))
# Input bundle for ``nstep_return``: per-step rewards, bootstrap value and terminal flag.
nstep_return_data = namedtuple('nstep_return_data', 'reward next_value done')
def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_gamma: Optional[torch.Tensor] = None):
    '''
    Overview:
        Calculate nstep return for DQN algorithm, support single agent case and multi agent case.
    Arguments:
        - data (:obj:`nstep_return_data`): The input data, nstep_return_data to calculate loss
        - gamma (:obj:`float` or :obj:`list`): Discount factor. A number (int or float) is shared by all \
            samples; a list (NGU policy case) holds per-sample discount tensors that are stacked along \
            dim 0 into shape :math:`(B, )`
        - nstep (:obj:`int`): nstep num
        - value_gamma (:obj:`torch.Tensor`): Discount factor for the bootstrap value, used instead of \
            ``gamma ** nstep`` when given (only in the numeric ``gamma`` case)
    Returns:
        - return (:obj:`torch.Tensor`): nstep return
    Shapes:
        - data (:obj:`nstep_return_data`): the nstep_return_data containing\
            ['reward', 'next_value', 'done']
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - next_value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`) :math:`(B, )`, whether done in last timestep
    Examples:
        >>> data = nstep_return_data(
        >>>     reward=torch.randn(3, 3),
        >>>     next_value=torch.randn(3),
        >>>     done=torch.randint(0, 2, (3,)),
        >>> )
        >>> loss = nstep_return(data, 0.99, 3)
    '''
    reward, next_value, done = data
    assert reward.shape[0] == nstep
    device = reward.device
    if done.dtype == torch.bool:
        # PyTorch does not support ``1 - bool_tensor``; promote the mask first.
        done = done.float()
    if isinstance(gamma, (int, float)):
        # reward_factor[i] = gamma ** i, built by repeated multiplication.
        reward_factor = torch.ones(nstep).to(device)
        for i in range(1, nstep):
            reward_factor[i] = gamma * reward_factor[i - 1]
        reward_factor = view_similar(reward_factor, reward)
        return_tmp = reward.mul(reward_factor).sum(0)
        if value_gamma is None:
            return_ = return_tmp + (gamma ** nstep) * next_value * (1 - done)
        else:
            return_ = return_tmp + value_gamma * next_value * (1 - done)
    elif isinstance(gamma, list):
        # if gamma is list, for NGU policy case: per-sample discount, stacked once outside the loop.
        gamma_tensor = torch.stack(gamma, dim=0).to(device)
        reward_factor = torch.ones([nstep + 1, done.shape[0]]).to(device)
        for i in range(1, nstep + 1):
            reward_factor[i] = gamma_tensor * reward_factor[i - 1]
        reward_factor = view_similar(reward_factor, reward)
        return_tmp = reward.mul(reward_factor[:nstep]).sum(0)
        # The last row holds gamma ** nstep and discounts the bootstrap value.
        return_ = return_tmp + reward_factor[nstep] * next_value * (1 - done)
    else:
        raise TypeError("The type of gamma should be float or list")
    return return_
# Input bundle for ``dist_1step_td_error`` (distributional 1-step TD); ``weight`` may be None.
dist_1step_td_data = namedtuple('dist_1step_td_data', 'dist next_dist act next_act reward done weight')
def dist_1step_td_error(
        data: namedtuple,
        gamma: float,
        v_min: float,
        v_max: float,
        n_atom: int,
) -> torch.Tensor:
    """
    Overview:
        1 step td_error for distributed q-learning based algorithm, support single agent case and multi \
        agent case.
    Arguments:
        - data (:obj:`dist_1step_td_data`): The input data, dist_1step_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - v_min (:obj:`float`): The min value of support
        - v_max (:obj:`float`): The max value of support
        - n_atom (:obj:`int`): The num of atom
    Returns:
        - loss (:obj:`torch.Tensor`): 1 step td error (cross-entropy to the projected target), 0-dim tensor
    Shapes:
        - data (:obj:`dist_1step_td_data`): the dist_1step_td_data containing\
            ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
        - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
        - next_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
        >>> next_dist = torch.randn(4, 3, 51).abs()
        >>> act = torch.randint(0, 3, (4,))
        >>> next_act = torch.randint(0, 3, (4,))
        >>> reward = torch.randn(4)
        >>> done = torch.randint(0, 2, (4,))
        >>> data = dist_1step_td_data(dist, next_dist, act, next_act, reward, done, None)
        >>> loss = dist_1step_td_error(data, 0.99, -10.0, 10.0, 51)
    """
    dist, next_dist, act, next_act, reward, done, weight = data
    device = reward.device
    assert len(reward.shape) == 1, reward.shape
    if done.dtype == torch.bool:
        # PyTorch does not support ``1 - bool_tensor``; promote the mask first.
        done = done.float()
    support = torch.linspace(v_min, v_max, n_atom).to(device)
    delta_z = (v_max - v_min) / (n_atom - 1)
    if len(act.shape) == 1:
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        batch_size = act.shape[0]
        batch_range = torch.arange(batch_size)
        if weight is None:
            weight = torch.ones_like(reward)
        next_dist = next_dist[batch_range, next_act].detach()
    else:
        # Multi agent: flatten (B, A) into B * A independent samples.
        reward = reward.unsqueeze(-1).repeat(1, act.shape[1])
        done = done.unsqueeze(-1).repeat(1, act.shape[1])
        batch_size = act.shape[0] * act.shape[1]
        batch_range = torch.arange(batch_size)
        action_dim = dist.shape[2]
        dist = dist.reshape(batch_size, action_dim, -1)
        reward = reward.reshape(batch_size, -1)
        done = done.reshape(batch_size, -1)
        next_dist = next_dist.reshape(batch_size, action_dim, -1)
        next_act = next_act.reshape(batch_size)
        next_dist = next_dist[batch_range, next_act].detach()
        next_dist = next_dist.reshape(batch_size, -1)
        act = act.reshape(batch_size)
        if weight is None:
            weight = torch.ones_like(reward)
    if isinstance(weight, torch.Tensor) and weight.dim() == 1:
        # A (B, ) sample weight must broadcast over the atom axis of the (B, n_atom) loss terms.
        weight = weight.unsqueeze(-1)
    target_z = reward + (1 - done) * gamma * support
    target_z = target_z.clamp(min=v_min, max=v_max)
    # Project the shifted support back onto the fixed atoms: split each atom's mass between its
    # lower (l) and upper (u) neighbours proportionally to the distance from b.
    b = (target_z - v_min) / delta_z
    l = b.floor().long()
    u = b.ceil().long()
    # Fix disappearing probability mass when l = b = u (b is int)
    l[(u > 0) * (l == u)] -= 1
    u[(l < (n_atom - 1)) * (l == u)] += 1

    proj_dist = torch.zeros_like(next_dist)
    # Flat index of each sample's first atom in proj_dist.view(-1); exact integer arithmetic.
    offset = (torch.arange(batch_size, device=device) * n_atom).unsqueeze(1).expand(batch_size, n_atom)
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))

    log_p = torch.log(dist[batch_range, act])
    loss = -(log_p * proj_dist * weight).sum(-1).mean()
    return loss
# Input bundle for ``dist_nstep_td_error``. The typename must equal the variable name: it is used in the
# repr and is how pickle locates the class inside this module.
dist_nstep_td_data = namedtuple(
    'dist_nstep_td_data', ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
)
def shape_fn_dntd(args, kwargs):
    r"""
    Overview:
        Return dntd shape for hpc. The data namedtuple is taken from the first positional argument when \
        there is one, otherwise from the ``data`` keyword argument.
    Returns:
        shape: [T, B, N, n_atom]
    """
    data = kwargs['data'] if len(args) <= 0 else args[0]
    return [data.reward.shape[0], *data.dist.shape]
@hpc_wrapper(
    shape_fn=shape_fn_dntd,
    namedtuple_data=True,
    include_args=[0, 1, 2, 3],
    include_kwargs=['data', 'gamma', 'v_min', 'v_max']
)
def dist_nstep_td_error(
        data: namedtuple,
        gamma: float,
        v_min: float,
        v_max: float,
        n_atom: int,
        nstep: int = 1,
        value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for distributed q-learning based algorithm, support single\
        agent case and multi agent case.
    Arguments:
        - data (:obj:`dist_nstep_td_data`): The input data, dist_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - v_min (:obj:`float`): The min value of support
        - v_max (:obj:`float`): The max value of support
        - n_atom (:obj:`int`): The num of atom
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor` or :obj:`float`): Discount applied to the bootstrap support \
            instead of ``gamma ** nstep`` when given
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error weighted by ``weight``, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted cross-entropy of each sample
    Shapes:
        - data (:obj:`dist_nstep_td_data`): the dist_nstep_td_data containing\
            ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
        - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
        - next_n_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
        - act (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_act (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor`, :obj:`float` or None): :math:`(B, )`, the training sample weight
    Examples:
        >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
        >>> next_n_dist = torch.randn(4, 3, 51).abs()
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> reward = torch.randn(5, 4)
        >>> data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
        >>> loss, _ = dist_nstep_td_error(data, 0.95, -10.0, 10.0, 51, 5)
    """
    dist, next_n_dist, act, next_n_act, reward, done, weight = data
    device = reward.device
    if done.dtype == torch.bool:
        # PyTorch does not support ``1 - bool_tensor``; promote the mask first.
        done = done.float()
    # Discounted n-step reward: sum_t gamma ** t * reward[t], (T, B) -> (B, ).
    reward_factor = torch.ones(nstep).to(device)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)
    support = torch.linspace(v_min, v_max, n_atom).to(device)
    delta_z = (v_max - v_min) / (n_atom - 1)
    if len(act.shape) == 1:
        reward = reward.unsqueeze(-1)
        done = done.unsqueeze(-1)
        batch_size = act.shape[0]
        batch_range = torch.arange(batch_size)
        if weight is None:
            weight = torch.ones_like(reward)
        elif isinstance(weight, float):
            weight = torch.tensor(weight, device=device)
        next_n_dist = next_n_dist[batch_range, next_n_act].detach()
    else:
        # Multi agent: flatten (B, A) into B * A independent samples.
        reward = reward.unsqueeze(-1).repeat(1, act.shape[1])
        done = done.unsqueeze(-1).repeat(1, act.shape[1])
        batch_size = act.shape[0] * act.shape[1]
        batch_range = torch.arange(batch_size)
        action_dim = dist.shape[2]
        dist = dist.reshape(batch_size, action_dim, -1)
        reward = reward.reshape(batch_size, -1)
        done = done.reshape(batch_size, -1)
        next_n_dist = next_n_dist.reshape(batch_size, action_dim, -1)
        next_n_act = next_n_act.reshape(batch_size)
        next_n_dist = next_n_dist[batch_range, next_n_act].detach()
        next_n_dist = next_n_dist.reshape(batch_size, -1)
        act = act.reshape(batch_size)
        if weight is None:
            weight = torch.ones_like(reward)
        elif isinstance(weight, float):
            weight = torch.tensor(weight, device=device)
    if value_gamma is None:
        target_z = reward + (1 - done) * (gamma ** nstep) * support
    elif isinstance(value_gamma, float):
        # Build the tensor on the data's device: a 1-element CPU tensor cannot mix with CUDA tensors.
        value_gamma = torch.tensor(value_gamma, device=device).unsqueeze(-1)
        target_z = reward + (1 - done) * value_gamma * support
    else:
        value_gamma = value_gamma.unsqueeze(-1)
        target_z = reward + (1 - done) * value_gamma * support
    target_z = target_z.clamp(min=v_min, max=v_max)
    # Project the shifted support back onto the fixed atoms: split each atom's mass between its
    # lower (l) and upper (u) neighbours proportionally to the distance from b.
    b = (target_z - v_min) / delta_z
    l = b.floor().long()
    u = b.ceil().long()
    # Fix disappearing probability mass when l = b = u (b is int)
    l[(u > 0) * (l == u)] -= 1
    u[(l < (n_atom - 1)) * (l == u)] += 1

    proj_dist = torch.zeros_like(next_n_dist)
    # Flat index of each sample's first atom in proj_dist.view(-1); exact integer arithmetic.
    offset = (torch.arange(batch_size, device=device) * n_atom).unsqueeze(1).expand(batch_size, n_atom)
    proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_n_dist * (u.float() - b)).view(-1))
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_n_dist * (b - l.float())).view(-1))

    assert (dist[batch_range, act] > 0.0).all(), ("dist act", dist[batch_range, act], "dist:", dist)
    log_p = torch.log(dist[batch_range, act])

    if len(weight.shape) == 1:
        # A (B, ) sample weight must broadcast over the atom axis.
        weight = weight.unsqueeze(-1)

    td_error_per_sample = -(log_p * proj_dist).sum(-1)

    loss = -(log_p * proj_dist * weight).sum(-1).mean()

    return loss, td_error_per_sample
# Input bundle for ``v_1step_td_error``; ``done`` and ``weight`` may be None.
v_1step_td_data = namedtuple('v_1step_td_data', 'v next_v reward done weight')
def v_1step_td_error(
data: namedtuple,
gamma: float,
criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
) -> torch.Tensor:
'''
Overview:
1 step td_error for distributed value based algorithm
Arguments:
- data (:obj:`v_1step_td_data`): The input data, v_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- criterion (:obj:`torch.nn.modules`): Loss function criterion
Returns:
- loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
Shapes:
- data (:obj:`v_1step_td_data`): the v_1step_td_data containing\
['v', 'next_v', 'reward', 'done', 'weight']
- v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
- next_v (:obj:`torch.FloatTensor`): :math:`(B, )`
- reward (:obj:`torch.FloatTensor`): :math:`(, B)`
- done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
- weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
Examples:
>>> v = torch.randn(5).requires_grad_(True)
>>> next_v = torch.randn(5)
>>> reward = torch.rand(5)
>>> done = torch.zeros(5)
>>> data = v_1step_td_data(v, next_v, reward, done, None)
>>> loss, td_error_per_sample = v_1step_td_error(data, 0.99)
'''
v, next_v, reward, done, weight = data
if weight is None:
weight = torch.ones_like(v)
if len(v.shape) == len(reward.shape):
if done is not None:
target_v = gamma * (1 - done) * next_v + reward
else:
target_v = gamma * next_v + reward
else:
if done is not None:
target_v = gamma * (1 - done).unsqueeze(1) * next_v + reward.unsqueeze(1)
else:
target_v = gamma * next_v + reward.unsqueeze(1)
td_error_per_sample = criterion(v, target_v.detach())
return (td_error_per_sample * weight).mean(), td_error_per_sample
# Input bundle for ``v_nstep_td_error``; ``weight`` and ``value_gamma`` may be None.
v_nstep_td_data = namedtuple('v_nstep_td_data', 'v next_n_v reward done weight value_gamma')
def v_nstep_td_error(
        data: namedtuple,
        gamma: float,
        nstep: int = 1,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none')  # noqa
) -> torch.Tensor:
    r"""
    Overview:
        Multistep (n step) td_error for value based algorithm. The target is the n-step return computed \
        by ``nstep_return`` with ``next_n_v`` as bootstrap value.
    Arguments:
        - data (:obj:`v_nstep_td_data`): The input data, v_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, expected to return an unreduced loss
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error weighted by ``weight`` and averaged, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted td error of each sample
    Shapes:
        - data (:obj:`v_nstep_td_data`): The v_nstep_td_data containing\
            ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']
        - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
        - next_n_v (:obj:`torch.FloatTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
        - value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step\
            we use value_gamma as the gamma discount value for next_v rather than gamma**n_step
    Examples:
        >>> v = torch.randn(5).requires_grad_(True)
        >>> next_v = torch.randn(5)
        >>> reward = torch.rand(5, 5)
        >>> done = torch.zeros(5)
        >>> data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
        >>> loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    """
    v, next_n_v, reward, done, weight, value_gamma = data
    if weight is None:
        weight = torch.ones_like(v)
    return_data = nstep_return_data(reward, next_n_v, done)
    target_v = nstep_return(return_data, gamma, nstep, value_gamma).detach()
    td_error_per_sample = criterion(v, target_v)
    weighted_error = td_error_per_sample * weight
    return weighted_error.mean(), td_error_per_sample
# Input bundle for the n-step Q-learning TD errors; ``weight`` may be None.
q_nstep_td_data = namedtuple('q_nstep_td_data', 'q next_n_q action next_n_action reward done weight')
# Input bundle for DQfD: n-step fields, their 1-step counterparts and the expert flag.
dqfd_nstep_td_data = namedtuple(
    'dqfd_nstep_td_data',
    'q next_n_q action next_n_action reward done done_one_step weight '
    'new_n_q_one_step next_n_action_one_step is_expert'
)
def shape_fn_qntd(args, kwargs):
    r"""
    Overview:
        Return qntd shape for hpc. The data namedtuple comes from the first positional argument if \
        present, otherwise from the ``data`` keyword argument.
    Returns:
        shape: [T, B, N]
    """
    data = args[0] if len(args) > 0 else kwargs['data']
    return [data.reward.shape[0]] + list(data.q.shape)
@hpc_wrapper(shape_fn=shape_fn_qntd, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma'])
def q_nstep_td_error(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for q-learning based algorithm
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, expected to return an unreduced loss
        - nstep (:obj:`int`): nstep num, default set to 1
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep); :math:`(B, )` \
            when ``cum_reward`` is True
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep =3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(reward)
    if done.dtype == torch.bool:
        # PyTorch does not support ``1 - bool_tensor``; promote the mask first.
        done = done.float()

    if len(action.shape) == 1:  # single agent case
        action = action.unsqueeze(-1)
    elif len(action.shape) > 1:  # MARL case
        reward = reward.unsqueeze(-1)
        weight = weight.unsqueeze(-1)
        done = done.unsqueeze(-1)
        if value_gamma is not None:
            value_gamma = value_gamma.unsqueeze(-1)

    q_s_a = q.gather(-1, action).squeeze(-1)

    target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)

    if cum_reward:
        # reward already holds the discounted n-step sum; only the bootstrap term is added here.
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def bdq_nstep_td_error(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for BDQ algorithm, referenced paper "Action Branching Architectures for \
        Deep Reinforcement Learning", link: https://arxiv.org/pdf/1711.08946.
        In fact, the original paper only provides the 1-step TD-error calculation method; here it is extended to \
        n-step: the TD-error of every action branch is computed independently and then averaged over branches.
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, expected to return an unreduced loss
        - nstep (:obj:`int`): nstep num, default set to 1
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error averaged over branches, 1-dim tensor
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` i.e. [batch_size, branch_num, action_bins_per_branch]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, D, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, D)`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, D)`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep); :math:`(B, )` \
            when ``cum_reward`` is True
        - done (:obj:`torch.BoolTensor` or :obj:`torch.FloatTensor`) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> action_per_branch = 3
        >>> next_q = torch.randn(8, 6, action_per_branch)
        >>> done = torch.randn(8)
        >>> action = torch.randint(0, action_per_branch, size=(8, 6))
        >>> next_action = torch.randint(0, action_per_branch, size=(8, 6))
        >>> nstep =3
        >>> q = torch.randn(8, 6, action_per_branch).requires_grad_(True)
        >>> reward = torch.rand(nstep, 8)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    if weight is None:
        weight = torch.ones_like(reward)
    if done.dtype == torch.bool:
        # PyTorch does not support ``1 - bool_tensor``; promote the mask first.
        done = done.float()
    # Add a trailing axis so per-sample reward/done broadcast over the D branches.
    reward = reward.unsqueeze(-1)
    done = done.unsqueeze(-1)
    if value_gamma is not None:
        value_gamma = value_gamma.unsqueeze(-1)

    q_s_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
    target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)

    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    # Average over the action branches: (B, D) -> (B, ).
    td_error_per_sample = td_error_per_sample.mean(-1)
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def shape_fn_qntd_rescale(args, kwargs):
    r"""
    Overview:
        Return qntd_rescale shape for hpc. Reads the data namedtuple from the first positional argument, \
        or from the ``data`` keyword argument when no positional argument is given.
    Returns:
        shape: [T, B, N]
    """
    source = kwargs['data'] if len(args) <= 0 else args[0]
    shape = [source.reward.shape[0]]
    shape.extend(source.q.shape)
    return shape
@hpc_wrapper(
    shape_fn=shape_fn_qntd_rescale, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma']
)
def q_nstep_td_error_with_rescale(
        data: namedtuple,
        gamma: Union[float, list],
        nstep: int = 1,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
        trans_fn: Callable = value_transform,
        inv_trans_fn: Callable = value_inv_transform,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error with value rescaling. The bootstrap q value is mapped back \
        with ``inv_trans_fn``, the n-step return is accumulated, and the result is rescaled with \
        ``trans_fn`` before comparing with ``q``.
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for the bootstrap value, forwarded to \
            ``nstep_return``
        - criterion (:obj:`torch.nn.modules`): Loss function criterion, expected to return an unreduced loss
        - trans_fn (:obj:`Callable`): Value transform function, default to value_transform\
            (refer to rl_utils/value_rescale.py)
        - inv_trans_fn (:obj:`Callable`): Value inverse transform function, default to value_inv_transform\
            (refer to rl_utils/value_rescale.py)
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): unweighted td error of each sample, :math:`(B, )`
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep =3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    rows = torch.arange(action.shape[0])
    q_s_a = q[rows, action]

    # Bootstrap in the untransformed value space, then rescale the whole n-step target.
    bootstrap = inv_trans_fn(next_n_q[rows, next_n_action])
    unscaled_target = nstep_return(nstep_return_data(reward, bootstrap, done), gamma, nstep, value_gamma)
    target_q_s_a = trans_fn(unscaled_target)

    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample
def dqfd_nstep_td_error(
        data: namedtuple,
        gamma: float,
        lambda_n_step_td: float,
        lambda_supervised_loss: float,
        margin_function: float,
        lambda_one_step_td: float = 1.,
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        DQfD loss: weighted sum of the n-step TD error, the 1-step TD error and the large-margin \
        supervised loss (the supervised term only counts for expert samples).
    Arguments:
        - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - lambda_n_step_td (:obj:`float`): Weight of the n-step TD error
        - lambda_supervised_loss (:obj:`float`): Weight of the supervised margin loss
        - margin_function (:obj:`float`): Margin added to the Q value of every action except the taken one \
            in the supervised loss
        - lambda_one_step_td (:obj:`float`): Weight of the 1-step TD error, default set to 1
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for the n-step target q_value; \
            the 1-step target always uses ``gamma``
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted n-step td_error + 1-step td_error + supervised margin loss, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Weighted sum of the absolute n-step td_error, 1-step td_error \
            and supervised margin loss, 1-dim tensor
        - loss_statistics (:obj:`tuple`): Means of the n-step td_error, 1-step td_error and supervised margin loss
    Shapes:
        - data (:obj:`dqfd_nstep_td_data`): the dqfd_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_one_step', 'weight'\
            , 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - done_one_step (:obj:`torch.BoolTensor`) :math:`(B, )`, done flag used by the 1-step target
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
        - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
        - is_expert (:obj:`torch.Tensor`) : :math:`(B, )` (or a broadcastable scalar), 1 for expert, 0 for agent
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> done_1 = torch.randn(4)
        >>> next_q_one_step = torch.randn(4, 3)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> next_action_one_step = torch.randint(0, 3, size=(4, ))
        >>> is_expert = torch.ones((4))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = dqfd_nstep_td_data(
        >>>     q, next_q, action, next_action, reward, done, done_1, None,
        >>>     next_q_one_step, next_action_one_step, is_expert
        >>> )
        >>> loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
        >>>     data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1,
        >>>     margin_function=0.8, nstep=nstep
        >>> )
    """
    q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \
        is_expert = data  # set is_expert flag(expert 1, agent 0)
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)
    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    target_q_s_a = next_n_q[batch_range, next_n_action]
    target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]

    # calculate n-step TD-loss
    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())

    # calculate 1-step TD-loss: only the first reward and the plain ``gamma`` discount are used
    # (``value_gamma`` belongs to the n-step target and is deliberately not applied here).
    # NOTE(review): with cum_reward=True, ``reward`` is presumably a (B, ) cumulative reward, in which case
    # ``reward[0]`` is sample 0's value broadcast over the whole batch -- confirm this is intended.
    reward_one_step = reward[0].unsqueeze(0)
    if cum_reward:
        target_q_s_a_one_step = reward_one_step + gamma * target_q_s_a_one_step * (1 - done_one_step)
    else:
        target_q_s_a_one_step = nstep_return(
            nstep_return_data(reward_one_step, target_q_s_a_one_step, done_one_step), gamma, 1, None
        )
    td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())

    # calculate the supervised large-margin loss: the margin is 0 at the taken action and margin_function
    # elsewhere, so JE = max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E); is_expert zeroes it for agent samples.
    # The margin is built directly on q's device, avoiding a host round trip on every call.
    margin = margin_function * torch.ones_like(q)  # q shape (B, A), action shape (B, )
    margin.scatter_(1, action.unsqueeze(1).long(), 0.)
    JE = is_expert * (torch.max(q + margin, dim=1)[0] - q_s_a)

    loss = (
        (
            lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
            lambda_supervised_loss * JE
        ) * weight
    ).mean()
    priority = lambda_n_step_td * td_error_per_sample.abs() + \
        lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs()
    return loss, priority, (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
def dqfd_nstep_td_error_with_rescale(
        data: namedtuple,
        gamma: float,
        lambda_n_step_td: float,
        lambda_supervised_loss: float,
        lambda_one_step_td: float,
        margin_function: float,
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
        trans_fn: Callable = value_transform,
        inv_trans_fn: Callable = value_inv_transform,
) -> torch.Tensor:
    """
    Overview:
        DQfD loss with value rescaling: weighted sum of the n-step TD error, the 1-step TD error and the \
        large-margin supervised loss. Bootstrap Q values are mapped back with ``inv_trans_fn`` before the \
        Bellman backup and the resulting targets are rescaled again with ``trans_fn``.
    Arguments:
        - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - lambda_n_step_td (:obj:`float`): Weight of the n-step TD error
        - lambda_supervised_loss (:obj:`float`): Weight of the supervised margin loss
        - lambda_one_step_td (:obj:`float`): Weight of the 1-step TD error
        - margin_function (:obj:`float`): Margin added to the Q value of every action except the taken one \
            in the supervised loss
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for the n-step target q_value; \
            the 1-step target always uses ``gamma``
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
        - trans_fn (:obj:`Callable`): Value transform function applied to the targets
        - inv_trans_fn (:obj:`Callable`): Value inverse transform function applied to the bootstrap Q values
    Returns:
        - loss (:obj:`torch.Tensor`): Weighted n-step td_error + 1-step td_error + supervised margin loss, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): Weighted sum of the absolute n-step td_error, 1-step td_error \
            and supervised margin loss, 1-dim tensor
        - loss_statistics (:obj:`tuple`): Means of the n-step td_error, 1-step td_error and supervised margin loss
    Shapes:
        - data (:obj:`dqfd_nstep_td_data`): The dqfd_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_one_step', 'weight'\
            , 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - done_one_step (:obj:`torch.BoolTensor`) :math:`(B, )`, done flag used by the 1-step target
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
        - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
        - is_expert (:obj:`torch.Tensor`) : :math:`(B, )` (or a broadcastable scalar), 1 for expert, 0 for agent
    """
    q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \
        is_expert = data  # set is_expert flag(expert 1, agent 0)
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)
    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    target_q_s_a = inv_trans_fn(next_n_q[batch_range, next_n_action])  # rescale
    target_q_s_a_one_step = inv_trans_fn(new_n_q_one_step[batch_range, next_n_action_one_step])  # rescale

    # calculate n-step TD-loss
    if cum_reward:
        if value_gamma is None:
            target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
        else:
            target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
    else:
        # to use value_gamma in n-step TD-loss
        target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
    target_q_s_a = trans_fn(target_q_s_a)  # rescale
    td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())

    # calculate 1-step TD-loss: only the first reward and the plain ``gamma`` discount are used
    # (``value_gamma`` belongs to the n-step target and must not be applied here).
    # NOTE(review): with cum_reward=True, ``reward`` is presumably a (B, ) cumulative reward, in which case
    # ``reward[0]`` is sample 0's value broadcast over the whole batch -- confirm this is intended.
    reward_one_step = reward[0].unsqueeze(0)
    if cum_reward:
        target_q_s_a_one_step = reward_one_step + gamma * target_q_s_a_one_step * (1 - done_one_step)
    else:
        target_q_s_a_one_step = nstep_return(
            nstep_return_data(reward_one_step, target_q_s_a_one_step, done_one_step), gamma, 1, None
        )
    target_q_s_a_one_step = trans_fn(target_q_s_a_one_step)  # rescale
    td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())

    # calculate the supervised large-margin loss: the margin is 0 at the taken action and margin_function
    # elsewhere, so JE = max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E); is_expert zeroes it for agent samples.
    # The margin is built directly on q's device, avoiding a host round trip on every call.
    margin = margin_function * torch.ones_like(q)  # q shape (B, A), action shape (B, )
    margin.scatter_(1, action.unsqueeze(1).long(), 0.)
    JE = is_expert * (torch.max(q + margin, dim=1)[0] - q_s_a)

    loss = (
        (
            lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
            lambda_supervised_loss * JE
        ) * weight
    ).mean()
    priority = lambda_n_step_td * td_error_per_sample.abs() + \
        lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs()
    return loss, priority, (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
# Input container for qrdqn_nstep_td_error; the loss function unpacks these fields positionally in this order.
qrdqn_nstep_td_data = namedtuple(
    'qrdqn_nstep_td_data', ('q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'tau', 'weight')
)
def qrdqn_nstep_td_error(
        data: namedtuple,
        gamma: float,
        nstep: int = 1,
        value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error in QRDQN (quantile regression loss with Huber kappa = 1).
    Arguments:
        - data (:obj:`qrdqn_nstep_td_data`): The input data, qrdqn_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value, :math:`(B, )`; \
            ``gamma ** nstep`` is used when it is None
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
    Shapes:
        - data (:obj:`qrdqn_nstep_td_data`): The qrdqn_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'tau', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N, M)` i.e. [batch_size, action_dim, num_quantiles]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N, M')`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - tau (:obj:`torch.FloatTensor`): quantile fraction of each predicted quantile, e.g. :math:`(B, M, 1)`; \
            must broadcast against the :math:`(B, M, M')` error tensor
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> tau = torch.rand(4, 3, 1)
        >>> data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, tau, None)
        >>> loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, tau, weight = data
    assert len(action.shape) == 1, action.shape
    assert len(next_n_action.shape) == 1, next_n_action.shape
    assert len(done.shape) == 1, done.shape
    assert len(q.shape) == 3, q.shape
    assert len(next_n_q.shape) == 3, next_n_q.shape
    assert len(reward.shape) == 2, reward.shape

    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    # shape: batch_size x num x 1, predicted quantiles of the taken action
    q_s_a = q[batch_range, action, :].unsqueeze(2)
    # shape: batch_size x 1 x num', target quantiles of the next action
    target_q_s_a = next_n_q[batch_range, next_n_action, :].unsqueeze(1)

    assert reward.shape[0] == nstep
    # reward_factor[i] = gamma ** i, matching reward's dtype and device
    reward_factor = torch.ones(nstep).to(reward)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    # shape: batch_size, discounted sum of the n rewards
    reward = torch.matmul(reward_factor, reward)

    # per-sample terms reshaped to batch_size x 1 x 1 so they broadcast over the quantile dims
    reward = reward.unsqueeze(-1).unsqueeze(-1)
    not_done = (1 - done).unsqueeze(-1).unsqueeze(-1)
    # shape: batch_size x 1 x num'
    if value_gamma is None:
        target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * not_done
    else:
        target_q_s_a = reward + value_gamma.unsqueeze(-1).unsqueeze(-1) * target_q_s_a * not_done

    # shape: batch_size x num x num', Huber loss between every (predicted, target) quantile pair
    u = F.smooth_l1_loss(target_q_s_a, q_s_a, reduction="none")
    # quantile Huber loss |tau - 1{target - q <= 0}| * u, summed over target quantiles, averaged over predicted ones
    # shape: batch_size
    loss = (u * (tau - (target_q_s_a - q_s_a).detach().le(0.).float()).abs()).sum(-1).mean(1)

    return (loss * weight).mean(), loss
def q_nstep_sql_td_error(
        data: namedtuple,
        gamma: float,
        alpha: float,
        nstep: int = 1,
        cum_reward: bool = False,
        value_gamma: Optional[torch.Tensor] = None,
        criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error for soft q-learning (SQL), where the bootstrap value is the \
        soft state value ``alpha * logsumexp(next_n_q / alpha)``.
    Arguments:
        - data (:obj:`q_nstep_td_data`): The input data, q_nstep_sql_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - alpha (:obj:`float`): A parameter to weight entropy term in a policy equation
        - nstep (:obj:`int`): nstep num, default set to 1
        - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target soft_q_value
        - criterion (:obj:`torch.nn.modules`): Loss function criterion
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
        - record_target_v (:obj:`torch.Tensor`): detached soft state value of the n-th next state (after \
            clamping infinities to +/-20), for logging
    Shapes:
        - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`, unused by the soft value target
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
        - record_target_v (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3).requires_grad_(True)
        >>> reward = torch.rand(nstep, 4)
        >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        >>> loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, weight = data
    assert len(action.shape) == 1, action.shape
    if weight is None:
        weight = torch.ones_like(action)

    batch_range = torch.arange(action.shape[0])
    q_s_a = q[batch_range, action]
    # soft state value: alpha * log(sum(exp(next_n_q / alpha))), computed stably with logsumexp
    target_v = alpha * torch.logsumexp(next_n_q / alpha, 1)
    # For an appropriate hyper-parameter alpha these hardcodes are unnecessary, but for other alphas the
    # soft value may overflow to +/-inf; clamp those entries to keep the loss finite.
    target_v[target_v == float("Inf")] = 20
    target_v[target_v == float("-Inf")] = -20
    # Logging copy. copy.deepcopy would raise when next_n_q requires grad (target_v is then a non-leaf).
    record_target_v = target_v.detach().clone()

    if cum_reward:
        if value_gamma is None:
            target_v = reward + (gamma ** nstep) * target_v * (1 - done)
        else:
            target_v = reward + value_gamma * target_v * (1 - done)
    else:
        target_v = nstep_return(nstep_return_data(reward, target_v, done), gamma, nstep, value_gamma)
    td_error_per_sample = criterion(q_s_a, target_v.detach())
    return (td_error_per_sample * weight).mean(), td_error_per_sample, record_target_v
# Input container for iqn_nstep_td_error; the loss function unpacks these fields positionally in this order.
iqn_nstep_td_data = namedtuple(
    'iqn_nstep_td_data',
    ('q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'replay_quantiles', 'weight'),
)
def iqn_nstep_td_error(
        data: namedtuple,
        gamma: float,
        nstep: int = 1,
        kappa: float = 1.0,
        value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error with in IQN, \
        referenced paper Implicit Quantile Networks for Distributional Reinforcement Learning \
        <https://arxiv.org/pdf/1806.06923.pdf>
    Arguments:
        - data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - kappa (:obj:`float`): Threshold of the Huber loss, default set to 1.0
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value, :math:`(B, )`; \
            ``gamma ** nstep`` is used when it is None
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
    Shapes:
        - data (:obj:`iqn_nstep_td_data`): The iqn_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'replay_quantiles', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. [tau x batch_size, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - replay_quantiles (:obj:`torch.FloatTensor`): :math:`(tau, B, 1)` (any shape with tau * B elements), \
            the quantile fraction of each entry of ``q``
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(3, 4, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(3, 4, 3).requires_grad_(True)
        >>> replay_quantile = torch.randn([3, 4, 1])
        >>> reward = torch.rand(nstep, 4)
        >>> data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
        >>> loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, replay_quantiles, weight = data
    assert len(action.shape) == 1, action.shape
    assert len(next_n_action.shape) == 1, next_n_action.shape
    assert len(done.shape) == 1, done.shape
    assert len(q.shape) == 3, q.shape
    assert len(next_n_q.shape) == 3, next_n_q.shape
    assert len(reward.shape) == 2, reward.shape

    if weight is None:
        weight = torch.ones_like(action)

    batch_size = done.shape[0]
    tau = q.shape[0]
    tau_prime = next_n_q.shape[0]

    action = action.repeat([tau, 1]).unsqueeze(-1)
    next_n_action = next_n_action.repeat([tau_prime, 1]).unsqueeze(-1)

    # shape: batch_size x tau x 1
    q_s_a = torch.gather(q, -1, action).permute([1, 0, 2])
    # shape: batch_size x tau_prime x 1
    target_q_s_a = torch.gather(next_n_q, -1, next_n_action).permute([1, 0, 2])

    assert reward.shape[0] == nstep
    # reward_factor[i] = gamma ** i. Built with reward's dtype AND exact device: a bare "cuda" device breaks
    # on non-default GPUs and a float32 factor breaks matmul with float64 rewards.
    reward_factor = torch.ones(nstep).to(reward)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    # shape: batch_size
    reward = torch.matmul(reward_factor, reward)

    # shape: batch_size x tau_prime
    if value_gamma is None:
        target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1)
    else:
        target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * \
            (1 - done).unsqueeze(-1)
    target_q_s_a = target_q_s_a.unsqueeze(-1)

    # shape: batch_size x tau' x tau x 1.
    bellman_errors = (target_q_s_a[:, :, None, :] - q_s_a[:, None, :, :])

    # The huber loss (see Section 2.3 of the paper) is defined via two cases:
    huber_loss = torch.where(
        bellman_errors.abs() <= kappa, 0.5 * bellman_errors ** 2, kappa * (bellman_errors.abs() - 0.5 * kappa)
    )

    # Reshape replay_quantiles to batch_size x num_tau_samples x 1
    replay_quantiles = replay_quantiles.reshape([tau, batch_size, 1]).permute([1, 0, 2])

    # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
    replay_quantiles = replay_quantiles[:, None, :, :].repeat([1, tau_prime, 1, 1])

    # shape: batch_size x tau_prime x tau x 1.
    quantile_huber_loss = (torch.abs(replay_quantiles - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa

    # shape: batch_size
    loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0]

    return (loss * weight).mean(), loss
# Input container for fqf_nstep_td_error; the loss function unpacks these fields positionally in this order.
fqf_nstep_td_data = namedtuple(
    'fqf_nstep_td_data',
    ('q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'quantiles_hats', 'weight'),
)
def fqf_nstep_td_error(
        data: namedtuple,
        gamma: float,
        nstep: int = 1,
        kappa: float = 1.0,
        value_gamma: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Overview:
        Multistep (1 step or n step) td_error with in FQF, \
        referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
        <https://arxiv.org/pdf/1911.02140.pdf>
    Arguments:
        - data (:obj:`fqf_nstep_td_data`): The input data, fqf_nstep_td_data to calculate loss
        - gamma (:obj:`float`): Discount factor
        - nstep (:obj:`int`): nstep num, default set to 1
        - kappa (:obj:`float`): Threshold of the Huber loss, default set to 1.0
        - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value, :math:`(B, )`; \
            ``gamma ** nstep`` is used when it is None
    Returns:
        - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
        - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
    Shapes:
        - data (:obj:`fqf_nstep_td_data`): The fqf_nstep_td_data containing\
            ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'quantiles_hats', 'weight']
        - q (:obj:`torch.FloatTensor`): :math:`(B, tau, N)` i.e. [batch_size, tau, action_dim]
        - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, tau', N)`
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
        - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
        - quantiles_hats (:obj:`torch.FloatTensor`): :math:`(B, tau)`
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
        - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
    Examples:
        >>> next_q = torch.randn(4, 3, 3)
        >>> done = torch.randn(4)
        >>> action = torch.randint(0, 3, size=(4, ))
        >>> next_action = torch.randint(0, 3, size=(4, ))
        >>> nstep = 3
        >>> q = torch.randn(4, 3, 3).requires_grad_(True)
        >>> quantiles_hats = torch.randn([4, 3])
        >>> reward = torch.rand(nstep, 4)
        >>> data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None)
        >>> loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep)
    """
    q, next_n_q, action, next_n_action, reward, done, quantiles_hats, weight = data
    assert len(action.shape) == 1, action.shape
    assert len(next_n_action.shape) == 1, next_n_action.shape
    assert len(done.shape) == 1, done.shape
    assert len(q.shape) == 3, q.shape
    assert len(next_n_q.shape) == 3, next_n_q.shape
    assert len(reward.shape) == 2, reward.shape

    if weight is None:
        weight = torch.ones_like(action)

    tau_prime = next_n_q.shape[1]

    # shape: batch_size x tau x 1
    q_s_a = evaluate_quantile_at_action(q, action)
    # shape: batch_size x tau_prime x 1
    target_q_s_a = evaluate_quantile_at_action(next_n_q, next_n_action)

    assert reward.shape[0] == nstep
    # reward_factor[i] = gamma ** i, built with reward's dtype and device so matmul works for e.g. float64 rewards
    reward_factor = torch.ones(nstep).to(reward)
    for i in range(1, nstep):
        reward_factor[i] = gamma * reward_factor[i - 1]
    reward = torch.matmul(reward_factor, reward)  # [batch_size]

    # shape: batch_size x tau_prime
    if value_gamma is None:
        target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1)
    else:
        target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * \
            (1 - done).unsqueeze(-1)
    target_q_s_a = target_q_s_a.unsqueeze(-1)

    # shape: batch_size x tau' x tau x 1.
    bellman_errors = (target_q_s_a.unsqueeze(2) - q_s_a.unsqueeze(1))

    # Huber loss with threshold kappa (quadratic inside [-kappa, kappa], linear outside), consistent with
    # iqn_nstep_td_error; identical to smooth_l1 for the default kappa=1.
    huber_loss = torch.where(
        bellman_errors.abs() <= kappa, 0.5 * bellman_errors ** 2, kappa * (bellman_errors.abs() - 0.5 * kappa)
    )

    # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
    quantiles_hats = quantiles_hats[:, None, :, None].repeat([1, tau_prime, 1, 1])

    # shape: batch_size x tau_prime x tau x 1.
    quantile_huber_loss = (torch.abs(quantiles_hats - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa

    # shape: batch_size
    loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0]

    return (loss * weight).mean(), loss
def evaluate_quantile_at_action(q_s, actions):
    """
    Overview:
        Pick, for every sample in the batch, the quantile values that belong to the given action.
    Arguments:
        - q_s (:obj:`torch.Tensor`): :math:`(B, M, A)` quantile values of every action.
        - actions (:obj:`torch.LongTensor`): :math:`(B, )` index of the selected action per sample.
    Returns:
        - q_s_a (:obj:`torch.Tensor`): :math:`(B, M, 1)` quantile values of the selected actions.
    """
    assert q_s.shape[0] == actions.shape[0]
    n_batch, n_quant = q_s.shape[:2]
    # Repeat each action index along the quantile axis so gather picks one action column per quantile.
    gather_index = actions.unsqueeze(-1).unsqueeze(-1).expand(n_batch, n_quant, 1)
    return torch.gather(q_s, 2, gather_index)
def fqf_calculate_fraction_loss(q_tau_i, q_value, quantiles, actions):
    """
    Overview:
        Calculate the fraction loss in FQF, \
        referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
        <https://arxiv.org/pdf/1911.02140.pdf>
        The gradient w.r.t. each interior fraction is evaluated without grad; the returned loss multiplies it \
        by ``quantiles[:, 1:-1]`` so gradients flow only into the fractions.
    Arguments:
        - q_tau_i (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles-1, action_dim)`
        - q_value (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles, action_dim)`, must require grad
        - quantiles (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles+1)`
        - actions (:obj:`torch.LongTensor`): :math:`(batch_size, )`
    Returns:
        - fraction_loss (:obj:`torch.Tensor`): fraction loss, 0-dim tensor
    """
    assert q_value.requires_grad

    batch_size, num_quantiles = q_value.shape[0], q_value.shape[1]

    with torch.no_grad():
        # quantile values of the taken action at the interior fractions: (B, M-1, 1)
        q_at_taus = evaluate_quantile_at_action(q_tau_i, actions)
        assert q_at_taus.shape == (batch_size, num_quantiles - 1, 1)
        # quantile values of the taken action at the midpoint fractions: (B, M, 1)
        q_at_hats = evaluate_quantile_at_action(q_value, actions)
        assert q_at_hats.shape == (batch_size, num_quantiles, 1)
        assert not q_at_hats.requires_grad

    # NOTE: Proposition 1 in the paper requires F^{-1} is non-decreasing.
    # This relaxes that requirement and computes gradients of the fractions
    # even when F^{-1} is not non-decreasing.
    lower_diff = q_at_taus - q_at_hats[:, :-1]
    lower_prev = torch.cat([q_at_hats[:, :1], q_at_taus[:, :-1]], dim=1)
    lower_sign = q_at_taus > lower_prev
    assert lower_diff.shape == lower_sign.shape

    upper_diff = q_at_taus - q_at_hats[:, 1:]
    upper_next = torch.cat([q_at_taus[:, 1:], q_at_hats[:, -1:]], dim=1)
    upper_sign = q_at_taus < upper_next
    assert upper_diff.shape == upper_sign.shape

    lower_term = torch.where(lower_sign, lower_diff, -lower_diff)
    upper_term = torch.where(upper_sign, upper_diff, -upper_diff)
    gradient_of_taus = (lower_term + upper_term).view(batch_size, num_quantiles - 1)
    assert not gradient_of_taus.requires_grad

    interior_taus = quantiles[:, 1:-1]
    assert gradient_of_taus.shape == interior_taus.shape

    # Chain rule: the (constant) gradient times the fractions yields the fraction-network gradient on backward.
    return (gradient_of_taus * interior_taus).sum(dim=1).mean()
# Input container for td_lambda_error: (value, reward, weight), unpacked positionally in this order.
td_lambda_data = namedtuple('td_lambda_data', ('value', 'reward', 'weight'))
def shape_fn_td_lambda(args, kwargs):
    r"""
    Overview:
        Return the td_lambda input shape for hpc, read from ``data.reward``. ``data`` is the first positional \
        argument when any positional arguments were passed, otherwise ``kwargs['data']``.
    Arguments:
        - args (:obj:`tuple`): Positional arguments of the wrapped ``td_lambda_error`` call.
        - kwargs (:obj:`dict`): Keyword arguments of the wrapped ``td_lambda_error`` call.
    Returns:
        shape: [T, B], i.e. ``data.reward.shape``
    """
    data = args[0] if len(args) > 0 else kwargs['data']
    # Both call styles must report the full [T, B] shape; the keyword path used to return only T.
    return data.reward.shape
@hpc_wrapper(
    shape_fn=shape_fn_td_lambda,
    namedtuple_data=True,
    include_args=[0, 1, 2],
    include_kwargs=['data', 'gamma', 'lambda_']
)
def td_lambda_error(data: namedtuple, gamma: float = 0.9, lambda_: float = 0.8) -> torch.Tensor:
    """
    Overview:
        TD(lambda) loss for a constant gamma and lambda: half the weighted mean squared error between the \
        value estimates of steps 0..T-1 and their (gradient-free) lambda returns.
        Terminal states get no special treatment: once a trajectory has terminated, the values and rewards \
        from that point on (*including the terminal state*, values[terminal] = 0) must be filled with zeros.
    Arguments:
        - data (:obj:`namedtuple`): td_lambda input data with fields ['value', 'reward', 'weight']
        - gamma (:obj:`float`): Constant discount factor gamma, should be in [0, 1], defaults to 0.9
        - lambda_ (:obj:`float`): Constant lambda, should be in [0, 1], defaults to 0.8
    Returns:
        - loss (:obj:`torch.Tensor`): Computed MSE loss, averaged over the batch
    Shapes:
        - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`, where T is trajectory length and B is batch,\
            which is the estimation of the state value at step 0 to T
        - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, the rewards from time step 0 to T-1
        - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )` or :math:`(T, B)`, broadcast against the \
            per-step loss; all ones when None
        - loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
    Examples:
        >>> T, B = 8, 4
        >>> value = torch.randn(T + 1, B).requires_grad_(True)
        >>> reward = torch.rand(T, B)
        >>> loss = td_lambda_error(td_lambda_data(value, reward, None))
    """
    value, reward, weight = data
    weight = torch.ones_like(reward) if weight is None else weight
    with torch.no_grad():
        lambda_return = generalized_lambda_returns(value, reward, gamma, lambda_)
    # value[T] only bootstraps the last step; its own target belongs to the next slice, so it is left out here
    squared_error = F.mse_loss(lambda_return, value[:-1], reduction='none')
    return 0.5 * (squared_error * weight).mean()
def generalized_lambda_returns(
        bootstrap_values: torch.Tensor,
        rewards: torch.Tensor,
        gammas: float,
        lambda_: float,
        done: Optional[torch.Tensor] = None
) -> torch.Tensor:
    r"""
    Overview:
        Functional equivalent to trfl.value_ops.generalized_lambda_returns
        https://github.com/deepmind/trfl/blob/2c07ac22512a16715cc759f0072be43a5d12ae45/trfl/value_ops.py#L74
        Passing in a number instead of a tensor makes that value constant for all steps and samples in the batch.
    Arguments:
        - bootstrap_values (:obj:`torch.Tensor` or :obj:`float`):
          estimation of the value at step 0 to *T*, of size [T_traj+1, batchsize]
        - rewards (:obj:`torch.Tensor`): The rewards from 0 to T-1, of size [T_traj, batchsize]
        - gammas (:obj:`torch.Tensor` or :obj:`float`):
          Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
        - lambda_ (:obj:`torch.Tensor` or :obj:`float`): Determining the mix of bootstrapping
          vs further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]
        - done (:obj:`torch.Tensor` or :obj:`float`):
          Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize]
    Returns:
        - return (:obj:`torch.Tensor`): Computed lambda return value
          for each state from 0 to T-1, of size [T_traj, batchsize]
    """
    # Broadcast every scalar argument to the [T_traj, batchsize] layout that multistep_forward_view indexes.
    if not isinstance(gammas, torch.Tensor):
        gammas = gammas * torch.ones_like(rewards)
    if not isinstance(lambda_, torch.Tensor):
        lambda_ = lambda_ * torch.ones_like(rewards)
    if done is not None and not isinstance(done, torch.Tensor):
        done = done * torch.ones_like(rewards)
    if isinstance(bootstrap_values, torch.Tensor):
        # values for steps 1..T, aligned with rewards[0..T-1]
        bootstrap_values_tp1 = bootstrap_values[1:, :]
    else:
        # a constant value estimate is the same for every step, so the shifted view is just a constant tensor
        bootstrap_values_tp1 = bootstrap_values * torch.ones_like(rewards)
    return multistep_forward_view(bootstrap_values_tp1, rewards, gammas, lambda_, done)
def multistep_forward_view(
        bootstrap_values: torch.Tensor,
        rewards: torch.Tensor,
        gammas: float,
        lambda_: float,
        done: Optional[torch.Tensor] = None
) -> torch.Tensor:
    r"""
    Overview:
        Same as trfl.sequence_ops.multistep_forward_view, implementing (12.18) in Sutton & Barto with a
        backward sweep over time (time is the first dimension). With ``V[t]`` = ``bootstrap_values[t]``,
        the value estimate of the state reached after step t:
        ```
        result[T-1] = rewards[T-1] + gammas[T-1] * V[T-1]
        for t in T-2 ... 0:
            result[t] = rewards[t] + gammas[t]*(lambdas[t]*result[t+1] + (1-lambdas[t])*V[t])
        ```
        Every future term of step t is additionally multiplied by ``(1 - done[t])``.
    Arguments:
        - bootstrap_values (:obj:`torch.Tensor`): Estimation of the value at *step 1 to T*, of size [T_traj, batchsize]
        - rewards (:obj:`torch.Tensor`): The rewards from 0 to T-1, of size [T_traj, batchsize]
        - gammas (:obj:`torch.Tensor`): Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
        - lambda_ (:obj:`torch.Tensor`): Mix of bootstrapping vs further accumulation of multistep returns \
            at each timestep, of size [T_traj, batchsize]; the element for T-1 is ignored (effectively 0), \
            as there is no information about future rewards.
        - done (:obj:`torch.Tensor`): Whether the episode is done at each step (from 0 to T-1), \
            of size [T_traj, batchsize]; treated as all zeros when None
    Returns:
        - ret (:obj:`torch.Tensor`): Computed lambda return value \
            for each state from 0 to T-1, of size [T_traj, batchsize]
    """
    if done is None:
        done = torch.zeros_like(rewards)
    not_done = 1 - done
    discounts = gammas * lambda_
    last = rewards.size(0) - 1
    returns = torch.empty_like(rewards)
    # The final step has no later lambda-return to blend in, so it bootstraps fully from the value estimate.
    returns[last] = rewards[last] + not_done[last] * gammas[last] * bootstrap_values[last]
    for t in range(last - 1, -1, -1):
        blended = discounts[t] * returns[t + 1] + (gammas[t] - discounts[t]) * bootstrap_values[t]
        returns[t] = rewards[t] + not_done[t] * blended
    return returns