# File size: 2,176 Bytes
# 6a62ffb |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from typing import Callable
import torch
from torch import zero_
from torch.nn import Module
from torch.nn.init import kaiming_normal_, xavier_normal_, normal_
def create_init_function(method: str = 'none') -> Callable[[Module], Module]:
    """Build a callable that (re)initializes ``module.weight`` according to ``method``.

    Supported methods:
      - ``'none'``      : leave the module untouched.
      - ``'he'``        : ``kaiming_normal_`` with PyTorch's default arguments.
      - ``'xavier'``    : ``xavier_normal_`` with PyTorch's default arguments.
      - ``'dcgan'``     : normal distribution with mean 0.0 and std 0.02.
      - ``'dcgan_001'`` : normal distribution with mean 0.0 and std 0.01.
      - ``'zero'``      : fill the weight with zeros.

    Args:
        method: Name of the initialization scheme.

    Returns:
        A function that modifies ``module.weight`` in place and returns the same
        ``module`` object.

    Raises:
        ValueError: Raised by the returned function, when it is called, if
            ``method`` is not one of the names above. The name is not checked
            when the function is created.
    """
    def init(module: Module) -> Module:
        if method == 'none':
            return module
        elif method == 'he':
            kaiming_normal_(module.weight)
            return module
        elif method == 'xavier':
            xavier_normal_(module.weight)
            return module
        elif method == 'dcgan':
            normal_(module.weight, 0.0, 0.02)
            return module
        elif method == 'dcgan_001':
            normal_(module.weight, 0.0, 0.01)
            return module
        elif method == "zero":
            # Zeroing a leaf tensor that requires grad in place is only allowed
            # with autograd recording disabled.
            with torch.no_grad():
                zero_(module.weight)
            return module
        else:
            # Previously this raised a bare string, which itself fails with
            # TypeError; raise a real exception so the message reaches the caller.
            raise ValueError("Invalid initialization method %s" % method)
    return init
class HeInitialization:
    """Callable initializer that applies Kaiming (He) normal init to ``module.weight``.

    The constructor arguments are stored and forwarded unchanged to
    ``torch.nn.init.kaiming_normal_`` on every call.
    """

    def __init__(self, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'):
        # ``a`` is forwarded as kaiming_normal_'s ``a`` argument, which the
        # PyTorch docs describe as the (float) negative slope of the rectifier,
        # so it is annotated as float. The default of 0 is unchanged.
        self.nonlinearity = nonlinearity
        self.mode = mode
        self.a = a

    def __call__(self, module: Module) -> Module:
        """Re-initialize ``module.weight`` in place and return ``module``."""
        with torch.no_grad():
            kaiming_normal_(module.weight, a=self.a, mode=self.mode, nonlinearity=self.nonlinearity)
        return module
class NormalInitialization:
    """Initializer that redraws ``module.weight`` from a normal distribution.

    Each call fills the weight in place, using the configured ``mean`` and
    ``std`` (standard deviation), and hands back the same module.
    """

    def __init__(self, mean: float = 0.0, std: float = 1.0):
        self.mean, self.std = mean, std

    def __call__(self, module: Module) -> Module:
        weight = module.weight
        with torch.no_grad():
            normal_(weight, self.mean, self.std)
        return module
class XavierInitialization:
    """Initializer that applies Xavier (Glorot) normal init to ``module.weight``.

    ``gain`` is the scaling factor passed to ``xavier_normal_``. The weight is
    overwritten in place and the same module is returned.
    """

    def __init__(self, gain: float = 1.0):
        self.gain = gain

    def __call__(self, module: Module) -> Module:
        with torch.no_grad():
            xavier_normal_(module.weight, gain=self.gain)
        return module
class ZeroInitialization:
    """Initializer that sets every element of ``module.weight`` to zero, in place."""

    def __call__(self, module: Module) -> Module:
        """Zero ``module.weight`` and return ``module``."""
        # ``torch.no_grad`` must be instantiated. The bare class is not a
        # context manager, so ``with torch.no_grad:`` raised on every call.
        # No-grad mode is needed because an in-place op on a leaf tensor that
        # requires grad is otherwise rejected by autograd.
        with torch.no_grad():
            zero_(module.weight)
        return module
class NoInitialization:
    """Identity initializer: returns ``module`` without touching its parameters."""

    def __call__(self, module: Module) -> Module:
        # A stray trailing ``|`` (scrape residue) made this line a syntax error.
        return module