File size: 3,849 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import torch
import treetensor.torch as ttorch
from torch.distributions import Normal, Independent


class ArgmaxSampler:
    '''
    Overview:
        Greedy sampler: deterministically picks the position of the largest score \
        along the last dimension of the input.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Reduce the last dimension of ``logit`` to the index of its largest entry.
        Arguments:
            - logit (:obj:`torch.Tensor`): Scores; the last dimension enumerates the candidates.
        Returns:
            - action (:obj:`torch.Tensor`): Integer indices, with the last dimension removed.
        '''
        candidate_axis = -1
        return torch.argmax(logit, dim=candidate_axis)


class MultinomialSampler:
    '''
    Overview:
        Stochastic sampler: draws one index per row from the categorical \
        distribution defined by unnormalized log-probabilities (logits).
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Draw a random index along the last dimension, weighted by ``softmax(logit)``.
        Arguments:
            - logit (:obj:`torch.Tensor`): Unnormalized log-probabilities; the last \
                dimension enumerates the categories.
        Returns:
            - action (:obj:`torch.Tensor`): Sampled category indices (last dimension removed).
        '''
        categorical = torch.distributions.Categorical(logits=logit)
        drawn = categorical.sample()
        return drawn


class MuSampler:
    '''
    Overview:
        Deterministic sampler for Gaussian-style outputs: hands back the ``mu`` field unchanged.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Pick out the ``mu`` attribute of the input (no sampling noise is added).
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor expected to expose a ``mu`` field.
        Returns:
            - action (:obj:`torch.Tensor`): The ``mu`` field, returned as-is (no copy).
        '''
        mean = logit.mu
        return mean


class ReparameterizationSampler:
    '''
    Overview:
        Stochastic sampler using the reparameterization trick: draws from a diagonal \
        Gaussian built from the input's ``mu`` and ``sigma`` fields.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Draw one reparameterized sample ``mu + sigma * eps`` via ``rsample``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor exposing ``mu`` (mean) and \
                ``sigma`` (standard deviation) fields.
        Returns:
            - action (:obj:`torch.Tensor`): A sample with the same shape as ``mu``.
        '''
        gaussian = Normal(loc=logit.mu, scale=logit.sigma)
        # Reinterpret the last batch dim as the event dim (a diagonal multivariate
        # Gaussian). This only changes shape bookkeeping; the drawn values are the
        # same as sampling ``gaussian`` directly.
        joint = Independent(gaussian, reinterpreted_batch_ndims=1)
        return joint.rsample()


class HybridStochasticSampler:
    '''
    Overview:
        Stochastic sampler for hybrid (discrete + continuous) action spaces: samples the \
        discrete action type and draws reparameterized continuous action arguments.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Sample ``action_type`` from a categorical distribution, then draw ``action_args`` \
            from a diagonal Gaussian with ``rsample``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor with an ``action_type`` field (logits) \
                and an ``action_args`` sub-tree holding ``mu`` and ``sigma``.
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree tensor with keys ``action_type`` \
                (sampled indices) and ``action_args`` (reparameterized sample).
        '''
        # Discrete part is sampled first, so the random-number consumption order
        # stays: categorical draw, then Gaussian draw.
        type_distribution = torch.distributions.Categorical(logits=logit.action_type)
        sampled_type = type_distribution.sample()

        gaussian_params = logit.action_args
        args_distribution = Independent(Normal(gaussian_params.mu, gaussian_params.sigma), 1)
        sampled_args = args_distribution.rsample()

        return ttorch.as_tensor({
            'action_type': sampled_type,
            'action_args': sampled_args,
        })


class HybridDeterminsticSampler:
    '''
    Overview:
        Hybrid deterministic sampler, return the argmax action type and the mu action args.
    Note:
        The class name keeps its historical misspelling ("Determinstic") so existing callers \
        keep working. The correctly spelled alias ``HybridDeterministicSampler`` refers to \
        the same class.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Return the argmax action type and the mu action args.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor with an ``action_type`` field (scores \
                over discrete types along the last dim) and an ``action_args`` sub-tree that \
                exposes ``mu``.
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree tensor with keys ``action_type`` (argmax \
                indices over the last dim) and ``action_args`` (the ``mu`` field, unchanged).
        '''
        action_type = logit.action_type.argmax(dim=-1)
        action_args = logit.action_args.mu
        return ttorch.as_tensor({
            'action_type': action_type,
            'action_args': action_args,
        })


# Correctly spelled alias. The original (misspelled) name stays for backward compatibility.
HybridDeterministicSampler = HybridDeterminsticSampler