Spaces:
Runtime error
Runtime error
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# This work is made available under the Nvidia Source Code License-NC. | |
# To view a copy of this license, check out LICENSE.md | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from imaginaire.third_party.bias_act.bias_act import FusedNonlinearity | |
class ScaledLeakyReLU(nn.Module):
    r"""Leaky ReLU followed by a constant multiplicative gain.

    Computes ``F.leaky_relu(x, negative_slope) * scale``. The default gain is
    ``2 ** 0.5`` (i.e. sqrt(2)).

    Args:
        negative_slope (float): Slope applied to negative inputs.
        scale (float): Constant the activation output is multiplied by.
        inplace (bool): If ``True``, the leaky-ReLU step overwrites the input
            tensor. The subsequent scaling still produces a new tensor, so the
            returned tensor is not the input object itself.
    """

    def __init__(self, negative_slope=0.2, scale=2 ** 0.5, inplace=False):
        super().__init__()
        self.negative_slope = negative_slope
        self.scale = scale
        self.inplace = inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The gain is deliberately applied out of place: an in-place ``mul_``
        # would also modify the tensor the in-place leaky ReLU keeps for its
        # backward pass.
        return F.leaky_relu(x, self.negative_slope, inplace=self.inplace) * self.scale

    def extra_repr(self) -> str:
        # Show the configuration when the module is printed, like torch's
        # built-in activation modules do.
        return (f'negative_slope={self.negative_slope}, scale={self.scale}, '
                f'inplace={self.inplace}')
def get_nonlinearity_layer(nonlinearity_type, inplace, **kwargs):
    r"""Return a nonlinearity layer.

    Args:
        nonlinearity_type (str or None):
            Type of nonlinear activation function. Supported values:

            - ``'fused_<name>'``: ``FusedNonlinearity(nonlinearity='<name>',
              **kwargs)``. The first six characters (``'fused'`` plus one
              separator character) are stripped to obtain ``<name>``.
            - ``'relu'``: ``nn.ReLU``.
            - ``'leakyrelu'``: ``nn.LeakyReLU`` with negative slope 0.2.
            - ``'scaled_leakyrelu'``: ``ScaledLeakyReLU`` with negative slope
              0.2, compiled with ``torch.jit.script`` when
              ``imaginaire.config.USE_JIT`` is truthy.
            - ``'prelu'``, ``'tanh'``, ``'sigmoid'``.
            - ``'softmax'`` or ``'softmax,<dim>'``: ``nn.Softmax`` over
              ``dim`` (default 1).
            - ``'none'``, ``''`` or ``None``: no layer.
        inplace (bool): If ``True``, set ``inplace=True`` when initializing
            the nonlinearity layer. Only the ``'relu'``, ``'leakyrelu'`` and
            ``'scaled_leakyrelu'`` types use it.
        **kwargs: Forwarded to ``FusedNonlinearity`` for ``'fused_*'`` types;
            ignored for every other type.

    Returns:
        (nn.Module or None): The nonlinearity layer, or ``None`` when no
        nonlinearity is requested.

    Raises:
        ValueError: If ``nonlinearity_type`` is not recognized, or the softmax
            dimension suffix is not an integer.
    """
    # Treat None the same as 'none' instead of failing on None.startswith.
    if nonlinearity_type is None or nonlinearity_type in ('none', ''):
        return None
    if nonlinearity_type.startswith('fused'):
        # Drop 'fused' plus its one-character separator: 'fused_relu' -> 'relu'.
        nonlinearity = FusedNonlinearity(nonlinearity=nonlinearity_type[6:], **kwargs)
    elif nonlinearity_type == 'relu':
        nonlinearity = nn.ReLU(inplace=inplace)
    elif nonlinearity_type == 'leakyrelu':
        nonlinearity = nn.LeakyReLU(0.2, inplace=inplace)
    elif nonlinearity_type == 'scaled_leakyrelu':
        nonlinearity = ScaledLeakyReLU(0.2, inplace=inplace)
        # Imported at call time, so USE_JIT is read on every call; presumably
        # kept local to avoid a circular import at module load -- TODO confirm.
        import imaginaire.config
        if imaginaire.config.USE_JIT:
            nonlinearity = torch.jit.script(nonlinearity)
    elif nonlinearity_type == 'prelu':
        nonlinearity = nn.PReLU()
    elif nonlinearity_type == 'tanh':
        nonlinearity = nn.Tanh()
    elif nonlinearity_type == 'sigmoid':
        nonlinearity = nn.Sigmoid()
    elif nonlinearity_type.startswith('softmax'):
        # An optional ',<dim>' suffix selects the softmax dimension.
        dim = nonlinearity_type.split(',')[1] if ',' in nonlinearity_type else 1
        try:
            dim = int(dim)
        except ValueError as err:
            raise ValueError(
                'Invalid softmax dimension in nonlinearity %s' % nonlinearity_type) from err
        nonlinearity = nn.Softmax(dim=dim)
    else:
        raise ValueError('Nonlinearity %s is not recognized' % nonlinearity_type)
    return nonlinearity