zjowowen's picture
init space
079c32c
raw
history blame
3.85 kB
import torch
import treetensor.torch as ttorch
from torch.distributions import Normal, Independent
class ArgmaxSampler:
    """
    Overview:
        Greedy discrete sampler: picks the highest-scoring index along the last dimension.
    """

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Select the index of the largest entry of ``logit`` along its last dimension.
        Arguments:
            - logit (:obj:`torch.Tensor`): Scores whose last dimension enumerates the candidate actions.
        Returns:
            - action (:obj:`torch.Tensor`): Integer indices, shaped like ``logit`` with the last dimension removed.
        """
        # keepdim=False is torch's default; spelled out so the dropped dimension is explicit.
        best_index = logit.argmax(dim=-1, keepdim=False)
        return best_index
class MultinomialSampler:
    """
    Overview:
        Stochastic discrete sampler: draws one index from the categorical distribution defined by the logits.
    """

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Draw one category index per batch element from ``Categorical(logits=logit)``.
        Arguments:
            - logit (:obj:`torch.Tensor`): Unnormalized log-probabilities; the last dimension enumerates \
                the categories and is normalized internally by ``Categorical``.
        Returns:
            - action (:obj:`torch.Tensor`): Sampled indices, shaped like ``logit`` with the last dimension removed.
        """
        # The draw uses torch's default global RNG, so it is reproducible under torch.manual_seed.
        return torch.distributions.Categorical(logits=logit).sample()
class MuSampler:
    """
    Overview:
        Deterministic continuous sampler: returns the ``mu`` entry of a tree-structured policy output.
    """

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        """
        Overview:
            Read the ``mu`` entry from ``logit`` and return it unchanged (no sampling, no noise).
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor expected to expose a ``mu`` field \
                (presumably the Gaussian mean of a continuous head; confirm against the producing model).
        Returns:
            - action (:obj:`torch.Tensor`): The ``mu`` entry itself, not a copy.
        """
        mean_action = logit.mu
        return mean_action
class ReparameterizationSampler:
    """
    Overview:
        Stochastic continuous sampler: draws a differentiable sample from the diagonal Gaussian \
        parameterized by ``logit.mu`` / ``logit.sigma``.
    """

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        """
        Overview:
            Build ``Normal(mu, sigma)``, reinterpret its last batch dimension as the event dimension, \
            and draw with ``rsample`` so gradients flow back into ``mu`` and ``sigma``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor exposing ``mu`` (location) and ``sigma`` (scale). \
                Both must have at least one dimension, which ``Independent(..., 1)`` requires.
        Returns:
            - action (:obj:`torch.Tensor`): Reparameterized sample with the broadcast shape of ``mu`` and ``sigma``.
        """
        loc, scale = logit.mu, logit.sigma
        # Independent only changes event/batch bookkeeping; the sample values come from the Normal.
        gaussian = Independent(Normal(loc, scale), reinterpreted_batch_ndims=1)
        return gaussian.rsample()
class HybridStochasticSampler:
    """
    Overview:
        Stochastic sampler for hybrid (discrete + continuous) actions: samples the discrete action type \
        from a categorical distribution and the continuous arguments from a reparameterized Gaussian.
    """

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        """
        Overview:
            Sample ``action_type`` from ``Categorical(logits=logit.action_type)``, then draw \
            ``action_args`` with ``rsample`` from ``Normal(mu, sigma)`` wrapped in ``Independent(..., 1)``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor with an ``action_type`` logits entry and an \
                ``action_args`` subtree exposing ``mu`` and ``sigma``.
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree with keys ``action_type`` (sampled indices) and \
                ``action_args`` (reparameterized sample).
        """
        # The discrete draw comes first so the global RNG is consumed in the same order as before.
        action_type = torch.distributions.Categorical(logits=logit.action_type).sample()
        args_head = logit.action_args
        gaussian = Independent(Normal(args_head.mu, args_head.sigma), reinterpreted_batch_ndims=1)
        action_args = gaussian.rsample()
        return ttorch.as_tensor({'action_type': action_type, 'action_args': action_args})
class HybridDeterminsticSampler:
    """
    Overview:
        Deterministic sampler for hybrid (discrete + continuous) actions: greedy action type \
        plus the ``mu`` entry of the continuous arguments. The class name keeps its historical spelling.
    """

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        """
        Overview:
            Take the argmax of ``logit.action_type`` over its last dimension and pair it with \
            ``logit.action_args.mu``; no randomness is involved.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor with an ``action_type`` logits entry and an \
                ``action_args`` subtree exposing ``mu``.
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree with keys ``action_type`` (argmax indices) and \
                ``action_args`` (the ``mu`` entry).
        """
        greedy_type = logit.action_type.argmax(dim=-1)
        mean_args = logit.action_args.mu
        packed = {'action_type': greedy_type, 'action_args': mean_args}
        return ttorch.as_tensor(packed)