|
import os |
|
|
|
import tensorflow as tf |
|
from pyprojroot.pyprojroot import here |
|
from tensorflow import reduce_mean |
|
from tensorflow.keras import Model |
|
from tensorflow.keras.applications import VGG19 |
|
from tensorflow.keras.applications.vgg19 import preprocess_input |
|
from tensorflow.keras.losses import ( |
|
Loss, |
|
MeanSquaredError, |
|
Reduction, |
|
SparseCategoricalCrossentropy, |
|
BinaryCrossentropy, |
|
) |
|
|
|
from . import vgg19_loss as vgg19 |
|
|
|
|
|
class Losses:
    """Bundle of loss functions whose results are scaled by a replica count.

    Every Keras loss object is created with ``Reduction.NONE``; the reduction
    is then done by hand: the unreduced loss is averaged with ``reduce_mean``
    and multiplied by ``1 / num_replicas``.
    """

    def __init__(self, num_replicas: int = 1, vgg_model_file: str = None):
        """Create the loss objects and resolve the VGG weight-file path.

        Args:
            num_replicas: Divisor applied to every averaged loss value.
            vgg_model_file: Path forwarded to ``vgg19.vgg_loss`` and
                ``vgg19.style_loss``. When ``None``, it defaults to
                ``<root>/models/vgg19/imagenet-vgg-verydeep-19.mat``.
        """
        self.num_replicas = num_replicas

        # Unreduced losses: the reduction is applied in ``_replica_mean``.
        self.SCCE = SparseCategoricalCrossentropy(
            from_logits=True, reduction=Reduction.NONE
        )
        self.MSE = MeanSquaredError(reduction=Reduction.NONE)
        self.MAE = tf.keras.losses.MeanAbsoluteError(reduction=Reduction.NONE)
        self.BCE = BinaryCrossentropy(from_logits=True, reduction=Reduction.NONE)

        # Feature extractor and input preprocessing used by ``perceptual_loss``.
        self.vgg = VGG.build()
        self.preprocess = preprocess_input

        # ``here()`` is asked for the project root; if it fails with a
        # RecursionError, fall back to the relative "GANime" directory.
        try:
            root_dir = here()
        except RecursionError:
            root_dir = "GANime"

        if vgg_model_file is None:
            vgg_model_file = os.path.join(
                root_dir, "models", "vgg19", "imagenet-vgg-verydeep-19.mat"
            )
        self.vgg_model_file = vgg_model_file

    def _replica_mean(self, unreduced_loss):
        """Average an unreduced loss and divide it by ``num_replicas``."""
        return reduce_mean(unreduced_loss) * (1.0 / self.num_replicas)

    def bce_loss(self, real, pred):
        """Return the replica-scaled mean binary cross-entropy.

        ``pred`` is treated as logits (``from_logits=True``).
        """
        return self._replica_mean(self.BCE(real, pred))

    def perceptual_loss(self, real, pred):
        """Return the replica-scaled MSE between VGG19 feature maps.

        Both inputs go through ``preprocess_input`` and then through the
        truncated VGG19 built in ``__init__``. The feature maps are divided by
        12.75, compared with ``mse_loss`` and the result is multiplied by 5e3.

        NOTE(review): ``preprocess_input`` presumably expects RGB images in
        the 0-255 range - confirm against the callers.
        """
        # The images must be run through ``self.vgg``; comparing the
        # preprocessed pixels directly would only be a rescaled pixel MSE.
        real_features = self.vgg(self.preprocess(real)) / 12.75
        pred_features = self.vgg(self.preprocess(pred)) / 12.75

        return self.mse_loss(real_features, pred_features) * 5e3

    def scce_loss(self, real, pred):
        """Return the replica-scaled mean sparse categorical cross-entropy.

        ``pred`` is treated as logits (``from_logits=True``).
        """
        return self._replica_mean(self.SCCE(real, pred))

    def mse_loss(self, real, pred):
        """Return the replica-scaled mean squared error."""
        return self._replica_mean(self.MSE(real, pred))

    def mae_loss(self, real, pred):
        """Return the replica-scaled mean absolute error."""
        return self._replica_mean(self.MAE(real, pred))

    def vgg_loss(self, real, pred):
        """Delegate to ``vgg19.vgg_loss`` using the configured weight file.

        Note that the helper is called with ``pred`` first and ``real`` second.
        """
        return vgg19.vgg_loss(pred, real, vgg_model_file=self.vgg_model_file)

    def style_loss(self, real, pred):
        """Delegate to ``vgg19.style_loss`` using the configured weight file.

        Note that the helper is called with ``pred`` first and ``real`` second.
        """
        return vgg19.style_loss(
            pred,
            real,
            vgg_model_file=self.vgg_model_file,
        )
|
|
|
|
|
class VGG:
    """Factory for a truncated, ImageNet-pretrained VGG19 feature extractor."""

    @staticmethod
    def build():
        """Return a Keras ``Model`` ending at layer index 20 of VGG19.

        The backbone is created without its classification head, with
        ImageNet weights, and accepts 3-channel images of any height and
        width.
        """
        # Index of the backbone layer whose output becomes the model output
        # (presumably ``block5_conv4`` in the stock Keras layer order - verify).
        feature_layer_index = 20

        backbone = VGG19(
            input_shape=(None, None, 3),
            weights="imagenet",
            include_top=False,
        )
        feature_output = backbone.layers[feature_layer_index].output

        return Model(inputs=backbone.input, outputs=feature_output)
|
|