|
from typing import Callable |
|
|
|
import torch |
|
from torch import zero_ |
|
from torch.nn import Module |
|
from torch.nn.init import kaiming_normal_, xavier_normal_, normal_ |
|
|
|
|
|
def create_init_function(method: str = 'none') -> Callable[[Module], Module]:
    """Build a function that initializes a module's ``weight`` in place.

    Supported values of ``method``:

    * ``'none'``      -- return the module untouched (``weight`` is never read).
    * ``'he'``        -- ``kaiming_normal_`` with its default arguments.
    * ``'xavier'``    -- ``xavier_normal_`` with its default arguments.
    * ``'dcgan'``     -- ``normal_`` with mean 0.0 and std 0.02.
    * ``'dcgan_001'`` -- ``normal_`` with mean 0.0 and std 0.01.
    * ``'zero'``      -- fill the weight with zeros.

    Args:
        method: Name of the initialization scheme.

    Returns:
        A callable that takes a module, initializes ``module.weight``
        according to ``method`` and returns the same module object.

    Raises:
        ValueError: Raised by the *returned* callable (not by this factory)
            when ``method`` is not one of the names listed above.
    """

    def init(module: Module) -> Module:
        # 'none' must return before touching ``module.weight`` so that it
        # also works for modules without a weight attribute.
        if method == 'none':
            return module

        if method == 'he':
            kaiming_normal_(module.weight)
        elif method == 'xavier':
            xavier_normal_(module.weight)
        elif method == 'dcgan':
            normal_(module.weight, 0.0, 0.02)
        elif method == 'dcgan_001':
            normal_(module.weight, 0.0, 0.01)
        elif method == "zero":
            # Run the in-place zeroing with autograd disabled.
            with torch.no_grad():
                zero_(module.weight)
        else:
            # Raising a bare string is itself a TypeError in Python 3;
            # raise a real exception so the message reaches the caller.
            raise ValueError("Invalid initialization method %s" % method)
        return module

    return init
|
|
|
|
|
class HeInitialization:
    """Callable that applies ``kaiming_normal_`` to a module's ``weight``.

    The constructor arguments are stored as attributes and forwarded
    unchanged to ``kaiming_normal_`` on every call.
    """

    def __init__(self, a: int = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
        self.a = a
        self.mode = mode
        self.nonlinearity = nonlinearity

    def __call__(self, module: Module) -> Module:
        """Initialize ``module.weight`` in place and return ``module``."""
        options = {
            'a': self.a,
            'mode': self.mode,
            'nonlinearity': self.nonlinearity,
        }
        with torch.no_grad():
            kaiming_normal_(module.weight, **options)
        return module
|
|
|
|
|
class NormalInitialization:
    """Callable that fills a module's ``weight`` using ``normal_``.

    ``mean`` and ``std`` are stored as attributes and passed to
    ``normal_`` on every call.
    """

    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.mean = mean
        self.std = std

    def __call__(self, module: Module) -> Module:
        """Initialize ``module.weight`` in place and return ``module``."""
        with torch.no_grad():
            normal_(module.weight, mean=self.mean, std=self.std)
        return module
|
|
|
|
|
class XavierInitialization:
    """Callable that applies ``xavier_normal_`` to a module's ``weight``.

    ``gain`` is stored as an attribute and passed to ``xavier_normal_``
    on every call.
    """

    def __init__(self, gain: float = 1.0):
        self.gain = gain

    def __call__(self, module: Module) -> Module:
        """Initialize ``module.weight`` in place and return ``module``."""
        target = module.weight
        with torch.no_grad():
            xavier_normal_(target, gain=self.gain)
        return module
|
|
|
|
|
class ZeroInitialization:
    """Callable that sets every element of a module's ``weight`` to zero."""

    def __call__(self, module: Module) -> Module:
        """Zero ``module.weight`` in place and return ``module``.

        Args:
            module: Module whose ``weight`` tensor is overwritten.

        Returns:
            The same ``module`` object that was passed in.
        """
        # ``torch.no_grad`` must be *called* to obtain a context manager;
        # using the bare class in a ``with`` statement fails at runtime.
        with torch.no_grad():
            zero_(module.weight)
        return module
|
|
|
class NoInitialization:
    """No-op initializer: calling it hands the module back unchanged."""

    def __call__(self, module: Module) -> Module:
        """Return ``module`` as-is without reading or modifying it."""
        untouched = module
        return untouched