# zjowowen's picture
# init space
# 079c32c
# raw
# history blame
# 3.24 kB
from typing import Optional
import torch
from torch import nn
from torch.distributions.transforms import TanhTransform
class NonegativeParameter(nn.Module):
    """
    Overview:
        This module stores a parameter in log space and outputs its exponential during the forward process, \
        so the output is always non-negative (strictly positive in practice, because of ``delta``).
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True, delta: float = 1e-8):
        """
        Overview:
            Initialize the NonegativeParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``, \
                the default value is 0 (a one-element tensor). Because ``delta`` is added before the log, \
                ``forward`` then returns approximately ``delta`` rather than exactly 0. Entries should be \
                non-negative; ``torch.log`` of a negative number yields NaN.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
            - delta (:obj:`float`): Small offset added to ``data`` before taking the log so that ``log(0)`` \
                stays finite. The same offset is reused by ``set_data``.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        # Remember the offset so ``set_data`` encodes new values exactly the same way as the constructor.
        self.delta = delta
        self.log_data = nn.Parameter(torch.log(data + delta), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the non-negative parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): ``exp(log_data)``, i.e. the stored value (plus ``delta``).
        """
        return torch.exp(self.log_data)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the non-negative parameter, using the same ``delta`` offset as the constructor.
        Arguments:
            data (:obj:`torch.Tensor`): The new value of the non-negative parameter.
        Note:
            ``self.log_data`` is replaced by a brand-new ``nn.Parameter`` object (its ``requires_grad`` flag is \
            carried over). Any optimizer that captured the old parameter object will not update the new one.
        """
        self.log_data = nn.Parameter(torch.log(data + self.delta), requires_grad=self.log_data.requires_grad)
class TanhParameter(nn.Module):
    """
    Overview:
        This module stores an unconstrained parameter and outputs its ``tanh`` during the forward process, \
        so the output always lies in (-1, 1).
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
        """
        Overview:
            Initialize the TanhParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``, \
                the default value is 0 (a one-element tensor). Values should lie strictly inside (-1, 1), since \
                they are mapped back through ``atanh`` and +/-1 would become +/-inf.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        # cache_size=0: a caching transform keys on tensor identity, and ``forward`` always passes the same
        # Parameter object, so a cache would keep returning a stale result (ignoring in-place optimizer updates
        # and reusing an already-freed autograd graph). Recompute tanh on every call instead.
        self.transform = TanhTransform(cache_size=0)
        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the tanh parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): ``tanh(data_inv)``, freshly computed on every call.
        """
        return self.transform(self.data_inv)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the tanh parameter.
        Arguments:
            data (:obj:`torch.Tensor`): The new value of the tanh parameter, expected strictly inside (-1, 1).
        Note:
            ``self.data_inv`` is replaced by a brand-new ``nn.Parameter`` object (its ``requires_grad`` flag is \
            carried over). Any optimizer that captured the old parameter object will not update the new one.
        """
        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=self.data_inv.requires_grad)