"""Factories that create ``torch.nn`` activation (nonlinearity) modules."""

from typing import Optional

from torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid

from tha3.module.module_factory import ModuleFactory


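class ReLUFactory(ModuleFactory):
    """Creates ``torch.nn.ReLU`` modules."""

    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        # Pass ``inplace`` by keyword, consistent with the other factories.
        return ReLU(inplace=self.inplace)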


class LeakyReLUFactory(ModuleFactory):
    """Creates ``torch.nn.LeakyReLU`` modules with a configurable negative slope."""
    def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):
        self.negative_slope = negative_slope
        self.inplace = inplace

    def create(self) -> Module:
        return LeakyReLU(inplace=self.inplace, negative_slope=self.negative_slope)


class ELUFactory(ModuleFactory):
    """Creates ``torch.nn.ELU`` modules with a configurable ``alpha``."""
    def __init__(self, inplace: bool = False, alpha: float = 1.0):
        self.alpha = alpha
        self.inplace = inplace

    def create(self) -> Module:
        return ELU(inplace=self.inplace, alpha=self.alpha)


class ReLU6Factory(ModuleFactory):
    """Creates ``torch.nn.ReLU6`` modules."""
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return ReLU6(inplace=self.inplace)


class SiLUFactory(ModuleFactory):
    """Creates ``torch.nn.SiLU`` modules."""
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return SiLU(inplace=self.inplace)


class HardswishFactory(ModuleFactory):
    """Creates ``torch.nn.Hardswish`` modules."""
    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        return Hardswish(inplace=self.inplace)


class TanhFactory(ModuleFactory):
    """Creates ``torch.nn.Tanh`` modules (no configuration options)."""
    def create(self) -> Module:
        return Tanh()


class SigmoidFactory(ModuleFactory):
    """Creates ``torch.nn.Sigmoid`` modules (no configuration options)."""
    def create(self) -> Module:
        return Sigmoid()


def resolve_nonlinearity_factory(nonlinearity_factory: Optional[ModuleFactory]) -> ModuleFactory:
    """Return ``nonlinearity_factory``, or a non-inplace ``ReLUFactory`` when it is ``None``."""
    if nonlinearity_factory is None:
        return ReLUFactory(inplace=False)
    else:
        return nonlinearity_factory