gomoku / DI-engine /ding /torch_utils /distribution.py
zjowowen's picture
init space
079c32c
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Tuple, Dict
import torch
import numpy as np
import torch.nn.functional as F
class Pd(object):
    """
    Overview:
        Abstract base for parameterizable probability distributions and the sampling helpers built on them.
        Every method defined here only raises ``NotImplementedError``; concrete subclasses override them.
    Interfaces:
        ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``

    .. tip::
        In derived classes, ``logits`` should be stored as an attribute of the instance.
    """

    def neglogp(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Compute the cross entropy (negative log-likelihood) of ``x`` under the distribution given by ``logits``.
        Arguments:
            - x (:obj:`torch.Tensor`): the input tensor
        Returns:
            - cross_entropy (:obj:`torch.Tensor`): the cross entropy loss
        """
        raise NotImplementedError()

    def entropy(self) -> torch.Tensor:
        """
        Overview:
            Compute the entropy of the softmax distribution defined by ``logits``.
            Subclasses may extend the signature (e.g. with a ``reduction`` argument).
        Returns:
            - entropy (:obj:`torch.Tensor`): the computed entropy
        """
        raise NotImplementedError()

    def noise_mode(self):
        """
        Overview:
            Pick an outcome after perturbing ``logits`` with random noise (stochastic selection).
        """
        raise NotImplementedError()

    def mode(self):
        """
        Overview:
            Pick the argmax of ``logits`` (deterministic selection).
        """
        raise NotImplementedError()

    def sample(self):
        """
        Overview:
            Draw a sample from the softmax distribution over ``logits`` (multinomial sampling).
        """
        raise NotImplementedError()
class CategoricalPd(Pd):
    """
    Overview:
        Categorical probability distribution sampler parameterized by unnormalized ``logits``.
        ``entropy``, ``noise_mode``, ``mode`` and ``sample`` treat the last dimension of ``logits`` as the
        category dimension.
    Interfaces:
        ``__init__``, ``update_logits``, ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``
    """

    def __init__(self, logits: torch.Tensor = None) -> None:
        """
        Overview:
            Init the Pd with logits.
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to sample from; may be ``None`` and set later via \
                ``update_logits``.
        """
        self.update_logits(logits)

    def update_logits(self, logits: torch.Tensor) -> None:
        """
        Overview:
            Replace the stored logits.
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to update
        """
        self.logits = logits

    def neglogp(self, x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the cross entropy between target indices ``x`` and ``logits``, i.e. the negative log \
            probability of ``x``.
        Arguments:
            - x (:obj:`torch.Tensor`): target class indices
            - reduction (:obj:`str`): ``None`` or ``'none'`` for per-element values, ``'mean'`` (default) for \
                their average; any other value is forwarded to ``F.cross_entropy`` unchanged.
        Returns:
            - cross_entropy (:obj:`torch.Tensor`): the cross entropy loss
        .. note::
            ``F.cross_entropy`` expects the class dimension at dim 1 of ``logits`` (e.g. shape ``(B, N)``).
        """
        # F.cross_entropy only understands the string 'none'; Python None would raise ValueError even
        # though None is the documented way (shared with ``entropy``) to request no reduction.
        if reduction is None:
            reduction = 'none'
        return F.cross_entropy(self.logits, x, reduction=reduction)

    def entropy(self, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the entropy of softmax(logits) over the last dimension.
        Arguments:
            - reduction (:obj:`str`): ``None`` or ``'none'`` for per-distribution entropy, ``'mean'`` \
                (default) for their average.
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        assert reduction in [None, 'none', 'mean']
        # Softmax is shift-invariant; subtracting the max keeps exp() from overflowing.
        a = self.logits - self.logits.max(dim=-1, keepdim=True)[0]
        ea = torch.exp(a)
        z = ea.sum(dim=-1, keepdim=True)
        p = ea / z
        # -log p_i = log z - a_i, so this sums p_i * (-log p_i).
        entropy = (p * (torch.log(z) - a)).sum(dim=-1)
        if reduction == 'mean':
            return entropy.mean()
        return entropy

    def noise_mode(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
        """
        Overview:
            Add Gumbel noise ``-log(-log(u))``, ``u ~ U[0, 1)``, to the logits and return the argmax over the \
            last dimension (Gumbel-max trick).
        Arguments:
            - viz (:obj:`bool`): Whether to also return numpy form of logits, noise and noise_logits; \
                short for ``visualize`` (tensors cannot be visualized directly in tb or text logs).
        Returns:
            - result (:obj:`torch.Tensor`): argmax of the noised logits
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray data for visualization; only returned \
                (as the second item of a tuple) when ``viz`` is True, otherwise just ``result`` is returned.
        """
        u = torch.rand_like(self.logits)
        u = -torch.log(-torch.log(u))
        noise_logits = self.logits + u
        result = noise_logits.argmax(dim=-1)
        if viz:
            viz_feature = {}
            viz_feature['logits'] = self.logits.data.cpu().numpy()
            viz_feature['noise'] = u.data.cpu().numpy()
            viz_feature['noise_logits'] = noise_logits.data.cpu().numpy()
            return result, viz_feature
        else:
            return result

    def mode(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
        """
        Overview:
            Return the argmax of the logits over the last dimension.
        Arguments:
            - viz (:obj:`bool`): Whether to also return numpy form of the logits; short for ``visualize``.
        Returns:
            - result (:obj:`torch.Tensor`): the logits argmax result
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray data for visualization; only returned \
                (as the second item of a tuple) when ``viz`` is True, otherwise just ``result`` is returned.
        """
        result = self.logits.argmax(dim=-1)
        if viz:
            viz_feature = {}
            viz_feature['logits'] = self.logits.data.cpu().numpy()
            return result, viz_feature
        else:
            return result

    def sample(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
        """
        Overview:
            Draw one category per distribution from softmax(logits) over the last dimension. Works for logits \
            of any rank >= 1; the result has the shape of ``logits`` without its last dimension.
        Arguments:
            - viz (:obj:`bool`): Whether to also return numpy form of the logits; short for ``visualize``.
        Returns:
            - result (:obj:`torch.Tensor`): the sampled category indices
            - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray data for visualization; only returned \
                (as the second item of a tuple) when ``viz`` is True, otherwise just ``result`` is returned.
        """
        # Use the last dim as the category dim, consistent with ``mode``/``entropy``/``noise_mode``.
        p = torch.softmax(self.logits, dim=-1)
        # torch.multinomial only accepts 1-D or 2-D input: collapse the leading (batch) dims into one,
        # draw one sample per row, then restore the leading shape.
        flat_p = p.reshape(-1, p.shape[-1])
        result = torch.multinomial(flat_p, 1).reshape(p.shape[:-1])
        if viz:
            viz_feature = {}
            viz_feature['logits'] = self.logits.data.cpu().numpy()
            return result, viz_feature
        else:
            return result
class CategoricalPdPytorch(torch.distributions.Categorical):
    """
    Overview:
        Wrapped ``torch.distributions.Categorical`` whose parameters can be replaced in place through \
        ``update_logits`` / ``update_probs``.
    Interfaces:
        ``__init__``, ``update_logits``, ``update_probs``, ``sample``, ``neglogp``, ``mode``, ``entropy``
    """

    def __init__(self, probs: torch.Tensor = None) -> None:
        """
        Overview:
            Initialize the CategoricalPdPytorch object.
        Arguments:
            - probs (:obj:`torch.Tensor`): The tensor of probabilities. If ``None``, the underlying \
                ``Categorical`` is not initialized yet and ``update_probs`` or ``update_logits`` must be \
                called before the distribution is used.
        """
        if probs is not None:
            self.update_probs(probs)

    def _drop_cached_params(self) -> None:
        """
        Overview:
            Forget previously computed ``probs`` / ``logits``. ``Categorical`` derives whichever of the two it \
            was not constructed with through a lazily evaluated property that caches the value on the \
            instance; without this reset, re-initializing with new parameters would leave that cached value \
            stale (e.g. ``log_prob`` reading old ``logits`` after ``update_probs``).
        """
        self.__dict__.pop('probs', None)
        self.__dict__.pop('logits', None)

    def update_logits(self, logits: torch.Tensor) -> None:
        """
        Overview:
            Re-initialize the distribution from logits.
        Arguments:
            - logits (:obj:`torch.Tensor`): logits to update
        """
        self._drop_cached_params()
        super().__init__(logits=logits)

    def update_probs(self, probs: torch.Tensor) -> None:
        """
        Overview:
            Re-initialize the distribution from probabilities.
        Arguments:
            - probs (:obj:`torch.Tensor`): probs to update
        """
        self._drop_cached_params()
        super().__init__(probs=probs)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Overview:
            Sample category indices from the distribution.
        Arguments:
            - sample_shape (:obj:`torch.Size`): extra leading sample dimensions, as in \
                ``Categorical.sample``; defaults to a single draw per distribution.
        Return:
            - result (:obj:`torch.Tensor`): the sampled indices, shape ``sample_shape + batch_shape``
        """
        return super().sample(sample_shape)

    def neglogp(self, actions: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
        """
        Overview:
            Calculate the negative log probability (cross entropy) of ``actions``.
        Arguments:
            - actions (:obj:`torch.Tensor`): the input action tensor
            - reduction (:obj:`str`): ``None`` or ``'none'`` for per-element values, ``'mean'`` (default) for \
                the mean over dim 0.
        Return:
            - cross_entropy (:obj:`torch.Tensor`): the returned cross_entropy loss
        """
        assert reduction in [None, 'none', 'mean']
        # log_prob gives log p(a) <= 0; negate it to get the negative log-likelihood the name promises.
        nll = -super().log_prob(actions)
        if reduction == 'mean':
            return nll.mean(dim=0)
        return nll

    def mode(self) -> torch.Tensor:
        """
        Overview:
            Return the most probable category.
        Return:
            - result(:obj:`torch.Tensor`): argmax of ``probs`` over the last dimension
        """
        return self.probs.argmax(dim=-1)

    def entropy(self, reduction: str = None) -> torch.Tensor:
        """
        Overview:
            Calculate the entropy of the distribution.
        Arguments:
            - reduction (:obj:`str`): ``None`` (default) or ``'none'`` for per-distribution entropy, \
                ``'mean'`` for their average.
        Returns:
            - entropy (:obj:`torch.Tensor`): the calculated entropy
        """
        assert reduction in [None, 'none', 'mean']
        entropy = super().entropy()
        if reduction == 'mean':
            return entropy.mean()
        return entropy