File size: 3,745 Bytes
6a62ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import layer_norm
from torch.nn import Module, BatchNorm2d, InstanceNorm2d, Parameter
from torch.nn.init import normal_, constant_

from tha3.nn.pass_through import PassThrough


class PixelNormalization(Module):
    """Rescale the input by its root-mean-square taken along dimension 1.

    Every element is divided by ``sqrt(mean(x ** 2, dim=1) + epsilon)``; the
    mean keeps its reduced dimension so the division broadcasts over ``x``.
    """

    def __init__(self, epsilon=1e-8):
        """Store ``epsilon``, the constant added under the square root."""
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x`` divided by its RMS over dimension 1."""
        # NOTE(review): dim=1 is presumably the channel axis of an NCHW
        # tensor -- confirm against callers.
        mean_square = (x ** 2).mean(dim=1, keepdim=True)
        rms = torch.sqrt(mean_square + self.epsilon)
        return x / rms


class NormalizationLayerFactory(ABC):
    """Abstract base for objects that build normalization modules."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def create(self, num_features: int, affine: bool = True) -> Module:
        """Build a normalization module.

        Args:
            num_features: Feature count handed to the created module.
            affine: Whether the created module should carry learnable
                parameters.

        Returns:
            The module produced by the concrete factory.
        """

    @staticmethod
    def resolve_2d(factory: Optional['NormalizationLayerFactory']) -> 'NormalizationLayerFactory':
        """Return ``factory``, or a new ``InstanceNorm2dFactory`` if it is ``None``."""
        return InstanceNorm2dFactory() if factory is None else factory


class Bias2d(Module):
    """Add a learnable bias of shape ``(1, num_features, 1, 1)`` to the input."""

    def __init__(self, num_features: int):
        """Create the bias parameter, initialised to zeros.

        Args:
            num_features: Size of the bias along dimension 1.
        """
        super().__init__()
        self.num_features = num_features
        initial_value = torch.zeros(1, num_features, 1, 1)
        self.bias = Parameter(initial_value)

    def forward(self, x):
        """Return ``x`` plus the bias (broadcast by the usual tensor rules)."""
        shifted = x + self.bias
        return shifted


class NoNorm2dFactory(NormalizationLayerFactory):
    """Factory whose modules perform no normalization."""

    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        """Return a ``Bias2d`` when ``affine`` is true, else a ``PassThrough``.

        Args:
            num_features: Feature count given to ``Bias2d``; unused when
                ``affine`` is false.
            affine: Selects between the bias module and the pass-through.
        """
        return Bias2d(num_features) if affine else PassThrough()


class BatchNorm2dFactory(NormalizationLayerFactory):
    """Factory for ``BatchNorm2d`` modules with optional custom initialisation."""

    def __init__(self,
                 weight_mean: Optional[float] = None,
                 weight_std: Optional[float] = None,
                 bias: Optional[float] = None):
        """Record the initialisation settings.

        Args:
            weight_mean: Mean used when re-initialising the weight, or ``None``.
            weight_std: Standard deviation used when re-initialising the
                weight, or ``None``.
            bias: Constant written into the bias, or ``None`` to leave the
                bias as ``BatchNorm2d`` created it.
        """
        super().__init__()
        self.weight_mean = weight_mean
        self.weight_std = weight_std
        self.bias = bias

    def get_weight_mean(self):
        """Return ``weight_mean``, falling back to 1.0 when it is unset."""
        return 1.0 if self.weight_mean is None else self.weight_mean

    def get_weight_std(self):
        """Return ``weight_std``, falling back to 0.02 when it is unset."""
        return 0.02 if self.weight_std is None else self.weight_std

    def create(self, num_features: int, affine: bool = True) -> Module:
        """Build a ``BatchNorm2d`` and apply the configured initialisation.

        The weight is re-initialised with ``normal_`` only if at least one of
        ``weight_mean`` / ``weight_std`` was supplied; the missing one takes
        its fallback value. Nothing is initialised when ``affine`` is false.
        """
        layer = BatchNorm2d(num_features=num_features, affine=affine)
        if not affine:
            return layer

        weight_init_requested = not (self.weight_mean is None and self.weight_std is None)
        if weight_init_requested:
            normal_(layer.weight, self.get_weight_mean(), self.get_weight_std())
        if self.bias is not None:
            constant_(layer.bias, self.bias)
        return layer


class InstanceNorm2dFactory(NormalizationLayerFactory):
    """Factory for ``InstanceNorm2d`` modules."""

    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        """Return ``InstanceNorm2d(num_features=num_features, affine=affine)``."""
        layer = InstanceNorm2d(num_features=num_features, affine=affine)
        return layer


class PixelNormFactory(NormalizationLayerFactory):
    """Factory for ``PixelNormalization`` modules."""

    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        """Return a default-constructed ``PixelNormalization``.

        Both ``num_features`` and ``affine`` are accepted for interface
        compatibility but are not used.
        """
        layer = PixelNormalization()
        return layer


class LayerNorm2d(Module):
    """Layer normalization over every dimension except the first.

    The input is normalized with ``layer_norm`` across ``x.size()[1:]``. When
    ``affine`` is true, the result is scaled and shifted by learnable
    parameters of shape ``(1, channels, 1, 1)``.
    """

    def __init__(self, channels: int, affine: bool = True):
        """Set up the optional scale and shift parameters.

        Args:
            channels: Size of the weight and bias along dimension 1.
            affine: If true, create a learnable weight (ones) and bias
                (zeros); otherwise the module has no parameters.
        """
        super(LayerNorm2d, self).__init__()
        self.channels = channels
        self.affine = affine

        if self.affine:
            self.weight = Parameter(torch.ones(1, channels, 1, 1))
            self.bias = Parameter(torch.zeros(1, channels, 1, 1))
        else:
            # Register the names as absent parameters so the attributes
            # always exist, as the built-in torch norm layers do.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        """Return the normalized input, scaled and shifted if ``affine``."""
        # NOTE(review): the (1, C, 1, 1) parameter shape suggests a 4-D
        # (N, C, H, W) input -- confirm against callers.
        normalized = layer_norm(x, x.size()[1:])
        if not self.affine:
            # No weight or bias exists here; the original code raised
            # AttributeError by reading them unconditionally.
            return normalized
        return normalized * self.weight + self.bias

class LayerNorm2dFactory(NormalizationLayerFactory):
    """Factory for ``LayerNorm2d`` modules."""

    def __init__(self):
        super().__init__()

    def create(self, num_features: int, affine: bool = True) -> Module:
        """Return ``LayerNorm2d`` with ``num_features`` passed as ``channels``."""
        layer = LayerNorm2d(channels=num_features, affine=affine)
        return layer