zhzluke96
update
d2b7e94
raw
history blame
1.83 kB
import logging
import torch
from torch import Tensor, nn
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class Normalizer(nn.Module):
    """Normalize tensors with scalar running statistics tracked by an EMA.

    A single running mean and running variance (both 0-dim buffers) are
    maintained over *all* elements of the inputs seen in training mode.
    Until the first update the buffers hold NaN, and the public
    ``running_mean`` / ``running_std`` properties fall back to 0 and 1, so
    the module acts as the identity.

    Args:
        momentum: Weight given to the newest batch statistic in the
            exponential moving average.
        eps: Value added to the running variance before the square root.
    """

    def __init__(self, momentum=0.01, eps=1e-9):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        # Annotations only; the tensors themselves are registered as buffers.
        self.running_mean_unsafe: Tensor
        self.running_var_unsafe: Tensor
        # NaN is the sentinel for "no statistics collected yet".
        self.register_buffer("running_mean_unsafe", torch.full([], torch.nan))
        self.register_buffer("running_var_unsafe", torch.full([], torch.nan))

    @property
    def started(self) -> bool:
        """Whether at least one update has stored a (non-NaN) running mean."""
        return not torch.isnan(self.running_mean_unsafe)

    @property
    def running_mean(self) -> Tensor:
        """Running mean, or zero if no statistics have been collected yet."""
        if not self.started:
            return torch.zeros_like(self.running_mean_unsafe)
        return self.running_mean_unsafe

    @property
    def running_std(self) -> Tensor:
        """Running standard deviation, or one if no statistics exist yet."""
        if not self.started:
            return torch.ones_like(self.running_var_unsafe)
        return (self.running_var_unsafe + self.eps).sqrt()

    @torch.no_grad()
    def _ema(self, a: Tensor, x: Tensor) -> Tensor:
        """Blend the old value ``a`` with the new value ``x`` by ``momentum``."""
        return (1 - self.momentum) * a + self.momentum * x

    @torch.no_grad()
    def update_(self, x):
        """Update the running mean and variance from ``x``.

        The first non-empty batch initialises the statistics directly; later
        batches are folded in with an exponential moving average. The whole
        update runs without autograd so the buffers never keep a graph that
        references ``x``.

        Args:
            x: Tensor whose elements are all reduced to scalar statistics.
        """
        if x.numel() == 0:
            # An empty batch has a NaN mean, which would wipe the statistics.
            return

        batch_mean = x.mean()
        if not self.started:
            self.running_mean_unsafe = batch_mean
            # The unbiased variance of a single element is NaN; use 0 instead
            # so the variance buffer is not poisoned permanently.
            if x.numel() > 1:
                self.running_var_unsafe = x.var()
            else:
                self.running_var_unsafe = torch.zeros_like(batch_mean)
        else:
            self.running_mean_unsafe = self._ema(
                self.running_mean_unsafe, batch_mean
            )
            # Deviations are measured against the freshly updated mean.
            self.running_var_unsafe = self._ema(
                self.running_var_unsafe, (x - self.running_mean).pow(2).mean()
            )

    def forward(self, x: Tensor, update=True):
        """Normalize ``x`` with the running statistics.

        Args:
            x: Input tensor.
            update: If true and the module is in training mode, refresh the
                running statistics from ``x`` first and record them as
                Python floats in ``self.stats``.

        Returns:
            ``(x - running_mean) / running_std``.
        """
        if self.training and update:
            self.update_(x)
            self.stats = dict(
                mean=self.running_mean.item(), std=self.running_std.item()
            )
        x = (x - self.running_mean) / self.running_std
        return x

    def inverse(self, x: Tensor):
        """Undo ``forward``: map a normalized tensor back to the input scale."""
        return x * self.running_std + self.running_mean