Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
class GANHingeLoss(nn.Module):
    """Hinge adversarial loss for GAN discriminator and generator updates.

    Discriminator (``for_discriminator=True``):
        real samples: mean(relu(1 - pred))
        fake samples: mean(relu(1 + pred))
    Generator (``for_discriminator=False``):
        mean(-pred); only defined for ``is_real=True``, i.e. the generator
        tries to make its fakes be scored as real.

    ``pred`` is the raw discriminator output tensor. The mean is taken over
    all of its elements, so any shape works.
    """

    def __init__(self):
        super().__init__()
        # Kept as a submodule attribute for compatibility with existing code
        # that may reference ``self.relu``.
        self.relu = nn.ReLU()

    def forward(self, pred, is_real, for_discriminator):
        """Return the scalar hinge loss for ``pred``.

        Implemented as ``forward`` rather than overriding ``__call__``, so that
        ``nn.Module`` hooks and other call machinery still run when the module
        is invoked as ``loss(pred, is_real, for_discriminator)``.

        Args:
            pred: Discriminator scores (tensor of any shape).
            is_real: Whether ``pred`` should be pushed toward "real".
            for_discriminator: True for the discriminator loss, False for the
                generator loss.

        Returns:
            0-dim tensor holding the mean loss.

        Raises:
            AssertionError: if the generator loss is requested with
                ``is_real=False``.
        """
        if for_discriminator:
            if is_real:
                # Penalize real scores below +1.
                return self.relu(1 - pred).mean()
            # Penalize fake scores above -1.
            return self.relu(1 + pred).mean()
        assert is_real, "The generator's hinge loss must be aiming for real"
        # Generator maximizes the discriminator score on its samples.
        return -pred.mean()