File size: 599 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from torch.distributions import Normal, Independent
import torch
# Scratch check of how `Independent` reshapes a diagonal Gaussian policy and
# what `log_prob` returns for samples drawn from it.
#
# `sigma` is a tiny constant (1e-7), i.e. an almost degenerate Gaussian; swap
# in `torch.randn([1, 2]).abs()` to try a random positive scale instead.
policy_logits = {'mu': torch.randn([1, 2]), 'sigma': torch.full([1, 2], 1e-7)}
num_of_sampled_actions = 20
mu, sigma = policy_logits['mu'], policy_logits['sigma']

# Normal(mu, sigma) alone has batch_shape (1, 2) and event_shape ().
# Independent(..., 1) reinterprets the rightmost batch dim as an event dim,
# giving batch_shape (1,) and event_shape (2,).
dist = Independent(Normal(mu, sigma), 1)
print(dist.batch_shape, dist.event_shape)

# sample_shape is documented as a torch.Size, so pass one rather than a tensor.
# Result shape: sample_shape + batch_shape + event_shape == (20, 1, 2).
sampled_actions = dist.sample(torch.Size([num_of_sampled_actions]))
# Independent.log_prob sums the per-dimension log-densities over the event
# dim, so the result has shape (20, 1); with a bare Normal it would be
# (20, 1, 2). Append `.unsqueeze(-1)` if a trailing singleton dim is wanted.
log_prob = dist.log_prob(sampled_actions)
print(sampled_actions.shape, log_prob.shape)
|