File size: 3,849 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import torch
import treetensor.torch as ttorch
from torch.distributions import Normal, Independent
class ArgmaxSampler:
    '''
    Overview:
        Greedy (deterministic) sampler: picks the position of the largest logit along the last dimension.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            For every leading index, select the position of the largest entry in the last dimension.
        Arguments:
            - logit (:obj:`torch.Tensor`): Unnormalized scores; the last dimension enumerates the choices.
        Returns:
            - action (:obj:`torch.Tensor`): Integer indices, i.e. ``logit`` with its last dimension reduced away.
        '''
        # Choices live on the trailing axis; reduce over it (keepdim stays False).
        choice_axis = -1
        return logit.argmax(dim=choice_axis)
class MultinomialSampler:
    '''
    Overview:
        Stochastic sampler: draws one category index from the categorical distribution defined by the logits.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Treat the last dimension of ``logit`` as unnormalized log-probabilities and draw one index per row.
        Arguments:
            - logit (:obj:`torch.Tensor`): Unnormalized log-probabilities; the last dimension enumerates the categories.
        Returns:
            - action (:obj:`torch.Tensor`): Sampled integer indices with the shape of ``logit`` minus its last dimension.
        '''
        # Categorical normalizes the logits internally, so no explicit softmax is needed.
        return torch.distributions.Categorical(logits=logit).sample()
class MuSampler:
    '''
    Overview:
        Deterministic sampler for continuous actions: hands back the ``mu`` field of the input unchanged.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Read and return ``logit.mu`` without sampling.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Container exposing a ``mu`` field (presumably the Gaussian mean \
                produced by a policy head -- confirm with the caller).
        Returns:
            - action (:obj:`torch.Tensor`): The ``mu`` field itself (same object, no copy).
        '''
        mean_action = logit.mu
        return mean_action
class ReparameterizationSampler:
    '''
    Overview:
        Draws a differentiable sample from the Gaussian described by the ``mu`` / ``sigma`` fields of the
        input, using the reparameterization trick (``rsample``).
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Build ``Independent(Normal(mu, sigma), 1)`` from ``logit`` and return one ``rsample`` draw.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Container exposing ``mu`` (passed as ``loc``) and ``sigma``
                (passed as ``scale``, the standard deviation) to :class:`Normal`. Both need at least one
                dimension; presumably the action dimension is last -- TODO confirm against the policy head.
        Returns:
            - action (:obj:`torch.Tensor`): A sample with the broadcast shape of ``mu`` and ``sigma``;
                gradients flow back to both.
        '''
        # Independent(..., 1) only moves the last batch dim into the event dim (relevant for log_prob,
        # not for the drawn values); it raises ValueError for 0-dim parameters.
        gaussian = Independent(Normal(logit.mu, logit.sigma), 1)
        return gaussian.rsample()
class HybridStochasticSampler:
    '''
    Overview:
        Stochastic sampler for hybrid (discrete + continuous) actions: samples the discrete ``action_type``
        from a categorical distribution and draws the continuous ``action_args`` with ``rsample``.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Sample ``action_type`` from ``Categorical(logits=logit.action_type)``, then draw ``action_args``
            from ``Independent(Normal(mu, sigma), 1)`` built from ``logit.action_args``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree with an ``action_type`` logits tensor and an ``action_args``
                sub-tree exposing ``mu`` and ``sigma``.
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree with ``action_type`` (sampled indices) and ``action_args``
                (reparameterized sample).
        '''
        # Discrete part first, so the RNG is consumed in the same order as before.
        sampled_type = torch.distributions.Categorical(logits=logit.action_type).sample()

        # Continuous part: rsample keeps gradients w.r.t. mu/sigma.
        args_params = logit.action_args
        gaussian = Independent(Normal(args_params.mu, args_params.sigma), 1)
        sampled_args = gaussian.rsample()

        return ttorch.as_tensor({
            'action_type': sampled_type,
            'action_args': sampled_args,
        })
class HybridDeterminsticSampler:
    '''
    Overview:
        Deterministic sampler for hybrid (discrete + continuous) actions: takes the argmax of the
        ``action_type`` logits and the ``mu`` of ``action_args``.
    .. note::
        The class name keeps its historical misspelling ("Determinstic") so existing callers keep working.
        The correctly spelled alias ``HybridDeterministicSampler``, defined right after this class,
        refers to the same class.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Return the argmax action type and the mu action args, without any sampling.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree with an ``action_type`` logits tensor (choices on the last
                dimension) and an ``action_args`` sub-tree exposing ``mu``.
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree with ``action_type`` (argmax indices over the last dimension)
                and ``action_args`` (the ``mu`` field, unchanged).
        '''
        action_type = logit.action_type.argmax(dim=-1)
        action_args = logit.action_args.mu
        return ttorch.as_tensor({
            'action_type': action_type,
            'action_args': action_args,
        })


# Correctly spelled alias; the misspelled original name above is kept for backward compatibility.
HybridDeterministicSampler = HybridDeterminsticSampler
|