File size: 1,787 Bytes
6a62ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from typing import Optional, Callable, Union

from torch.nn import Module

from tha3.module.module_factory import ModuleFactory
from tha3.nn.init_function import create_init_function
from tha3.nn.nonlinearity_factory import resolve_nonlinearity_factory
from tha3.nn.normalization import NormalizationLayerFactory
from tha3.nn.spectral_norm import apply_spectral_norm


def wrap_conv_or_linear_module(module: Module,
                               initialization_method: Union[str, Callable[[Module], Module]],
                               use_spectral_norm: bool):
    """Initialize ``module`` and pass it through ``apply_spectral_norm``.

    Args:
        module: The module to initialize and wrap.
        initialization_method: Either a string, which is handed to
            ``create_init_function`` to obtain the initializer, or a callable
            that takes a module and returns a module, which is used as the
            initializer directly.
        use_spectral_norm: Forwarded unchanged as the second argument of
            ``apply_spectral_norm``.

    Returns:
        Whatever ``apply_spectral_norm`` returns for the initialized module.
    """
    # A string is a name to be resolved; anything else is assumed to already
    # be the initializer callable.
    initializer = (
        create_init_function(initialization_method)
        if isinstance(initialization_method, str)
        else initialization_method
    )
    initialized_module = initializer(module)
    return apply_spectral_norm(initialized_module, use_spectral_norm)


class BlockArgs:
    """Container for the settings used when wrapping a module.

    The constructor stores ``initialization_method``, ``use_spectral_norm`` and
    ``normalization_layer_factory`` exactly as given.  ``nonlinearity_factory``
    is stored as the result of ``resolve_nonlinearity_factory`` applied to the
    argument.
    """

    def __init__(self,
                 initialization_method: Union[str, Callable[[Module], Module]] = 'he',
                 use_spectral_norm: bool = False,
                 normalization_layer_factory: Optional[NormalizationLayerFactory] = None,
                 nonlinearity_factory: Optional[ModuleFactory] = None):
        """Store the given settings.

        Args:
            initialization_method: A string name or a callable taking a module
                and returning a module.  Defaults to the string ``'he'``.
            use_spectral_norm: Flag later forwarded by ``wrap_module``.
                Defaults to ``False``.
            normalization_layer_factory: Stored as-is; may be ``None``.
            nonlinearity_factory: Passed through
                ``resolve_nonlinearity_factory`` before being stored; may be
                ``None``.
        """
        self.initialization_method = initialization_method
        self.use_spectral_norm = use_spectral_norm
        self.normalization_layer_factory = normalization_layer_factory
        # The raw argument is not kept; only the resolved factory is stored.
        self.nonlinearity_factory = resolve_nonlinearity_factory(nonlinearity_factory)

    def wrap_module(self, module: Module) -> Module:
        """Wrap ``module`` with this object's initializer and spectral-norm flag.

        Delegates to ``wrap_conv_or_linear_module`` with the callable from
        ``get_init_func`` and ``self.use_spectral_norm``, and returns its result.
        """
        init_func = self.get_init_func()
        return wrap_conv_or_linear_module(module, init_func, self.use_spectral_norm)

    def get_init_func(self) -> Callable[[Module], Module]:
        """Return the initializer as a callable.

        A non-string ``initialization_method`` is returned unchanged; a string
        is converted with ``create_init_function``.
        """
        method = self.initialization_method
        if not isinstance(method, str):
            return method
        return create_init_function(method)