from torch.nn import Module | |
from torch.nn.utils import spectral_norm | |
def apply_spectral_norm(module: Module, use_spectrial_norm: bool = False) -> Module:
    """Optionally wrap ``module`` with spectral normalization.

    Args:
        module: The module to (optionally) normalize.
        use_spectrial_norm: When true, ``module`` is passed through
            ``torch.nn.utils.spectral_norm`` and the result of that call is
            returned. The (misspelled) parameter name is kept as-is so that
            existing keyword callers keep working.

    Returns:
        The result of ``spectral_norm(module)`` when enabled, otherwise
        ``module`` itself, untouched.
    """
    if use_spectrial_norm:
        return spectral_norm(module)
    return module
from torch.nn import Module | |
from torch.nn.utils import spectral_norm | |
# NOTE(review): this file defines `apply_spectral_norm` more than once with
# identical bodies; the later definition is the one bound at import time.
# Consider removing the duplicate.
def apply_spectral_norm(module: Module, use_spectrial_norm: bool = False) -> Module:
    """Return ``module``, spectrally normalized only if requested.

    Args:
        module: The module to (optionally) normalize.
        use_spectrial_norm: Switch for the normalization. When false (the
            default) the input is handed back unchanged; when true it is
            passed to ``torch.nn.utils.spectral_norm``. The (misspelled)
            parameter name is preserved for backward compatibility.

    Returns:
        ``spectral_norm(module)`` if ``use_spectrial_norm`` is true, else
        the very same ``module`` object.
    """
    if not use_spectrial_norm:
        return module
    return spectral_norm(module)