File size: 3,503 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
from typing import Optional

import tensorflow as tf
from pyprojroot.pyprojroot import here
from tensorflow import reduce_mean
from tensorflow.keras import Model
from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input
from tensorflow.keras.losses import (
    BinaryCrossentropy,
    Loss,
    MeanSquaredError,
    Reduction,
    SparseCategoricalCrossentropy,
)

from . import vgg19_loss as vgg19


class Losses:
    """Reduction-aware loss functions for multi-replica (distributed) training.

    Every Keras loss object is created with ``Reduction.NONE`` and the mean
    is taken manually, scaled by ``1 / num_replicas``, so that per-replica
    gradients sum to the correct global gradient under ``tf.distribute``
    strategies.
    """

    def __init__(self, num_replicas: int = 1, vgg_model_file: Optional[str] = None):
        """Initialize the loss objects and resolve the VGG19 weight file path.

        Args:
            num_replicas: Number of distributed replicas; each reduced loss
                is scaled by ``1 / num_replicas``.
            vgg_model_file: Path to the ``imagenet-vgg-verydeep-19.mat``
                weight file. If ``None``, a default path under the project
                root (``models/vgg19/...``) is used.
        """
        self.num_replicas = num_replicas
        self.SCCE = SparseCategoricalCrossentropy(
            from_logits=True, reduction=Reduction.NONE
        )
        self.MSE = MeanSquaredError(reduction=Reduction.NONE)
        self.MAE = tf.keras.losses.MeanAbsoluteError(reduction=Reduction.NONE)
        self.BCE = BinaryCrossentropy(from_logits=True, reduction=Reduction.NONE)

        # Truncated VGG19 feature extractor (see the VGG class in this module).
        self.vgg = VGG.build()
        self.preprocess = preprocess_input

        # here() walks up the tree to find the project root; on some layouts
        # it can hit Python's recursion limit, so fall back to a relative
        # project-directory name.
        try:
            root_dir = here()
        except RecursionError:
            root_dir = "GANime"

        self.vgg_model_file = (
            os.path.join(root_dir, "models", "vgg19", "imagenet-vgg-verydeep-19.mat")
            if vgg_model_file is None
            else vgg_model_file
        )

    def _reduce(self, loss):
        """Mean over all elements, scaled by ``1 / num_replicas``."""
        return reduce_mean(loss) * (1.0 / self.num_replicas)

    def bce_loss(self, real, pred):
        """Replica-scaled mean binary cross-entropy (``pred`` is logits)."""
        return self._reduce(self.BCE(real, pred))

    def perceptual_loss(self, real, pred):
        """MSE between VGG-preprocessed, rescaled images, weighted by 5e3.

        NOTE(review): despite the name, this does not pass the images
        through ``self.vgg`` — it compares preprocessed pixels directly.
        Confirm whether a VGG feature pass was intended.
        """
        y_true_preprocessed = self.preprocess(real)
        y_pred_preprocessed = self.preprocess(pred)
        # 12.75 = 255 / 20: rescales the [0, 255]-range preprocessed values.
        y_true_scaled = y_true_preprocessed / 12.75
        y_pred_scaled = y_pred_preprocessed / 12.75

        return self.mse_loss(y_true_scaled, y_pred_scaled) * 5e3

    def scce_loss(self, real, pred):
        """Replica-scaled mean sparse categorical cross-entropy (logits)."""
        return self._reduce(self.SCCE(real, pred))

    def mse_loss(self, real, pred):
        """Replica-scaled mean squared error."""
        return self._reduce(self.MSE(real, pred))

    def mae_loss(self, real, pred):
        """Replica-scaled mean absolute error."""
        return self._reduce(self.MAE(real, pred))

    def vgg_loss(self, real, pred):
        """Content loss from the external ``vgg19_loss`` module."""
        return vgg19.vgg_loss(pred, real, vgg_model_file=self.vgg_model_file)

    def style_loss(self, real, pred):
        """Style (Gram-matrix) loss from the external ``vgg19_loss`` module."""
        return vgg19.style_loss(
            pred,
            real,
            vgg_model_file=self.vgg_model_file,
        )


class VGG:
    """Factory for a truncated VGG19 feature extractor."""

    @staticmethod
    def build():
        """Build VGG19 (ImageNet weights, no classifier head) cut at layer 20.

        Returns:
            A ``Model`` mapping an image of any spatial size (H, W, 3) to
            the activations of VGG19's layer index 20.
        """
        # Fully-convolutional backbone: spatial dims left unspecified.
        backbone = VGG19(
            input_shape=(None, None, 3), weights="imagenet", include_top=False
        )
        feature_layer = backbone.layers[20]
        return Model(inputs=backbone.input, outputs=feature_layer.output)