from collections import namedtuple

import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal

# Input bundle shared by both A2C loss functions; fields are unpacked positionally.
a2c_data = namedtuple('a2c_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight'])
# Output bundle: the three scalar loss terms produced by the A2C loss functions.
a2c_loss = namedtuple('a2c_loss', ['policy_loss', 'value_loss', 'entropy_loss'])


def _a2c_loss_from_dist(
        dist: torch.distributions.Distribution,
        action: torch.Tensor,
        value: torch.Tensor,
        adv: torch.Tensor,
        return_: torch.Tensor,
        weight,
) -> a2c_loss:
    """
    Overview:
        Compute the three A2C loss terms from an already-built policy distribution. \
        Shared by ``a2c_error`` and ``a2c_error_continuous`` so that both variants use exactly the same arithmetic.
    Arguments:
        - dist (:obj:`torch.distributions.Distribution`): policy distribution providing ``log_prob`` and ``entropy``
        - action (:obj:`torch.Tensor`): actions whose log-probability is evaluated under ``dist``
        - value (:obj:`torch.Tensor`): predicted state values
        - adv (:obj:`torch.Tensor`): advantages multiplying the log-probabilities
        - return_ (:obj:`torch.Tensor`): regression targets for ``value``
        - weight (:obj:`torch.Tensor` or :obj:`None`): per-sample weights; ``None`` means all ones
    Returns:
        - a2c_loss (:obj:`namedtuple`): ``(policy_loss, value_loss, entropy_loss)``, each a 0-dim tensor
    """
    if weight is None:
        # No weights supplied: every sample contributes equally.
        weight = torch.ones_like(value)
    logp = dist.log_prob(action)
    # NOTE: this is the weighted mean entropy itself (not negated); the caller decides its sign in the total loss.
    entropy_loss = (dist.entropy() * weight).mean()
    policy_loss = -(logp * adv * weight).mean()
    # reduction='none' keeps per-sample errors so the weights can be applied before averaging.
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return a2c_loss(policy_loss, value_loss, entropy_loss)


def a2c_error(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C(Advantage Actor-Critic) (arXiv:1602.01783) for discrete action space
    Arguments:
        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the a2c loss item, all of them are 0-dim tensors
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> data = a2c_data(
        >>>     logit=torch.randn(2, 3),
        >>>     action=torch.randint(0, 3, (2, )),
        >>>     value=torch.randn(2, ),
        >>>     adv=torch.randn(2, ),
        >>>     return_=torch.randn(2, ),
        >>>     weight=torch.ones(2, ),
        >>> )
        >>> loss = a2c_error(data)
    """
    logit, action, value, adv, return_, weight = data
    # Unnormalized logits parameterize a categorical policy over the N actions.
    dist = torch.distributions.categorical.Categorical(logits=logit)
    return _a2c_loss_from_dist(dist, action, value, adv, return_, weight)


def a2c_error_continuous(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C(Advantage Actor-Critic) (arXiv:1602.01783) for continuous action space
    Arguments:
        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the a2c loss item, all of them are 0-dim tensors
    Shapes:
        - logit (:obj:`dict`): mapping with keys ``'mu'`` and ``'sigma'``, each a :obj:`torch.FloatTensor` \
            of shape :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> data = a2c_data(
        >>>     logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
        >>>     action=torch.randn(2, 3),
        >>>     value=torch.randn(2, ),
        >>>     adv=torch.randn(2, ),
        >>>     return_=torch.randn(2, ),
        >>>     weight=torch.ones(2, ),
        >>> )
        >>> loss = a2c_error_continuous(data)
    """
    logit, action, value, adv, return_, weight = data
    # Independent(..., 1) treats the last dimension as a single event, so log_prob and
    # entropy are summed over the action dimensions and come out with one value per sample.
    dist = Independent(Normal(logit['mu'], logit['sigma']), 1)
    return _a2c_loss_from_dist(dist, action, value, adv, return_, weight)