|
from typing import Optional |
|
import torch |
|
from torch import nn |
|
from torch.distributions.transforms import TanhTransform |
|
|
|
|
|
class NonegativeParameter(nn.Module):
    """
    Overview:
        This module will output a non-negative parameter during the forward process. The value is kept
        internally as ``log(data + delta)`` and exponentiated on every forward call, so the output can
        never be negative.
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True, delta: float = 1e-8):
        """
        Overview:
            Initialize the NonegativeParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to
                ``None``, the default value is 0.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
            - delta (:obj:`float`): Small offset added to ``data`` before taking the log, so that a value
                of 0 maps to a finite log. The same offset is reused by ``set_data``.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        # Remember the offset so that ``set_data`` applies the same one as the constructor.
        self.delta = delta
        self.log_data = nn.Parameter(torch.log(data + delta), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the non-negative parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): The generated parameter, i.e. ``exp(log_data)``.
        """
        return torch.exp(self.log_data)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the non-negative parameter. The underlying ``log_data`` is replaced by a
            newly created ``nn.Parameter`` that keeps the previous ``requires_grad`` flag.
        Arguments:
            data (:obj:`torch.Tensor`): The new value of the non-negative parameter.
        """
        # Instances created before ``delta`` was stored (e.g. unpickled modules) fall back to the
        # historical constant 1e-8.
        delta = getattr(self, "delta", 1e-8)
        self.log_data = nn.Parameter(torch.log(data + delta), requires_grad=self.log_data.requires_grad)
|
|
|
|
|
class TanhParameter(nn.Module):
    """
    Overview:
        This module will output a tanh parameter during the forward process. The value is kept
        internally as ``atanh(data)`` and passed through ``tanh`` on every forward call.
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
        """
        Overview:
            Initialize the TanhParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to
                ``None``, the default value is 0. Values should lie strictly inside (-1, 1), because
                ``atanh`` is infinite at the boundary and undefined beyond it.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        # Caching is deliberately left disabled (default ``cache_size=0``). The transform cache is keyed
        # on object identity, and ``self.data_inv`` is always the same Parameter object, so a cached
        # transform would keep returning the first forward result even after the parameter has been
        # updated in place (e.g. by an optimizer step).
        self.transform = TanhTransform()

        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the tanh parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): The generated parameter, i.e. ``tanh(data_inv)``, recomputed
                from the current value of ``data_inv`` on every call.
        """
        return self.transform(self.data_inv)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the tanh parameter. The underlying ``data_inv`` is replaced by a newly
            created ``nn.Parameter`` that keeps the previous ``requires_grad`` flag.
        Arguments:
            data (:obj:`torch.Tensor`): The new value of the tanh parameter.
        """
        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=self.data_inv.requires_grad)
|
|