File size: 1,902 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from typing import Optional
from torch.nn import Module, ReLU, LeakyReLU, ELU, ReLU6, Hardswish, SiLU, Tanh, Sigmoid
from tha3.module.module_factory import ModuleFactory
class ReLUFactory(ModuleFactory):
    """Factory that builds ``torch.nn.ReLU`` modules.

    Args:
        inplace: value forwarded to ``ReLU`` as its ``inplace`` flag.
    """

    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        """Construct and return a new ``ReLU`` using the stored ``inplace`` flag."""
        return ReLU(inplace=self.inplace)
class LeakyReLUFactory(ModuleFactory):
    """Factory that builds ``torch.nn.LeakyReLU`` modules.

    Args:
        inplace: value forwarded to ``LeakyReLU`` as its ``inplace`` flag.
        negative_slope: value forwarded to ``LeakyReLU`` as ``negative_slope``.
    """

    def __init__(self, inplace: bool = False, negative_slope: float = 1e-2):
        self.inplace = inplace
        self.negative_slope = negative_slope

    def create(self) -> Module:
        """Construct and return a new ``LeakyReLU`` from the stored settings."""
        return LeakyReLU(negative_slope=self.negative_slope, inplace=self.inplace)
class ELUFactory(ModuleFactory):
    """Factory that builds ``torch.nn.ELU`` modules.

    Args:
        inplace: value forwarded to ``ELU`` as its ``inplace`` flag.
        alpha: value forwarded to ``ELU`` as ``alpha``.
    """

    def __init__(self, inplace: bool = False, alpha: float = 1.0):
        self.inplace = inplace
        self.alpha = alpha

    def create(self) -> Module:
        """Construct and return a new ``ELU`` from the stored settings."""
        return ELU(alpha=self.alpha, inplace=self.inplace)
class ReLU6Factory(ModuleFactory):
    """Factory that builds ``torch.nn.ReLU6`` modules.

    Args:
        inplace: value forwarded to ``ReLU6`` as its ``inplace`` flag.
    """

    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        """Construct and return a new ``ReLU6`` using the stored ``inplace`` flag."""
        module = ReLU6(inplace=self.inplace)
        return module
class SiLUFactory(ModuleFactory):
    """Factory that builds ``torch.nn.SiLU`` modules.

    Args:
        inplace: value forwarded to ``SiLU`` as its ``inplace`` flag.
    """

    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        """Construct and return a new ``SiLU`` using the stored ``inplace`` flag."""
        module = SiLU(inplace=self.inplace)
        return module
class HardswishFactory(ModuleFactory):
    """Factory that builds ``torch.nn.Hardswish`` modules.

    Args:
        inplace: value forwarded to ``Hardswish`` as its ``inplace`` flag.
    """

    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def create(self) -> Module:
        """Construct and return a new ``Hardswish`` using the stored ``inplace`` flag."""
        module = Hardswish(inplace=self.inplace)
        return module
class TanhFactory(ModuleFactory):
    """Factory that builds ``torch.nn.Tanh`` modules (no configuration)."""

    def create(self) -> Module:
        """Construct and return a new ``Tanh`` module."""
        module = Tanh()
        return module
class SigmoidFactory(ModuleFactory):
    """Factory that builds ``torch.nn.Sigmoid`` modules (no configuration)."""

    def create(self) -> Module:
        """Construct and return a new ``Sigmoid`` module."""
        module = Sigmoid()
        return module
def resolve_nonlinearity_factory(nonlinearity_fatory: Optional[ModuleFactory]) -> ModuleFactory:
    """Return the given nonlinearity factory, falling back to a ReLU factory.

    Args:
        nonlinearity_fatory: the factory to use, or ``None`` to request the
            default. The misspelled parameter name is kept so that existing
            keyword-argument callers keep working.

    Returns:
        ``nonlinearity_fatory`` unchanged when it is not ``None``; otherwise a
        new ``ReLUFactory(inplace=False)``.
    """
    if nonlinearity_fatory is not None:
        return nonlinearity_fatory
    return ReLUFactory(inplace=False)
|