File size: 1,494 Bytes
3be620b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import List, Literal

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, layers
from tensorflow.keras.losses import Loss

from .lpips import LPIPS

from ..discriminator.model import NLayerDiscriminator


class VQLPIPSWithDiscriminator(Loss):
    """Reconstruction loss for a VQ autoencoder: weighted L1 plus LPIPS.

    Despite the class name, no discriminator / GAN term is computed here;
    ``call`` returns only the mean reconstruction (negative log-likelihood)
    term.
    """

    def __init__(
        self, *, pixelloss_weight: float = 1.0, perceptual_weight: float = 1.0, **kwargs
    ):
        """Create the loss.

        Args:
            pixelloss_weight: Multiplier applied to the per-element L1 term.
            perceptual_weight: Multiplier applied to the LPIPS term. When it is
                ``<= 0`` the LPIPS network is not evaluated at all in ``call``.
            **kwargs: Forwarded unchanged to ``keras.losses.Loss``
                (e.g. ``name``, ``reduction``).
        """
        super().__init__(**kwargs)
        self.pixelloss_weight = pixelloss_weight
        # Unreduced LPIPS so it can be summed with the element-wise L1 term
        # before the single mean taken in ``call``.
        self.perceptual_loss = LPIPS(reduction=tf.keras.losses.Reduction.NONE)
        self.perceptual_weight = perceptual_weight

    def call(
        self,
        y_true,
        y_pred,
    ):
        """Return the scalar mean of the weighted L1 + LPIPS reconstruction loss.

        Args:
            y_true: Target images.
            y_pred: Reconstructed images; must be broadcast-compatible with
                ``y_true`` for the element-wise difference.

        Returns:
            A scalar tensor equal to
            ``mean(pixelloss_weight * |y_true - y_pred|
            + perceptual_weight * LPIPS(y_true, y_pred))``,
            with the LPIPS term omitted when ``perceptual_weight <= 0``.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        # Same convention as Keras built-in losses: bring the targets to the
        # prediction dtype so e.g. float32 targets work with float16
        # (mixed-precision) predictions instead of failing on subtraction.
        y_true = tf.cast(y_true, y_pred.dtype)

        # pixelloss_weight used to be stored but never applied; honour it.
        reconstruction_loss = self.pixelloss_weight * tf.abs(y_true - y_pred)
        if self.perceptual_weight > 0:
            # NOTE(review): relies on the unreduced LPIPS output broadcasting
            # against the per-element L1 tensor (e.g. a (B, 1, 1, 1) result
            # for NHWC images). A (B,)-shaped result would broadcast over the
            # channel axis instead — confirm against LPIPS's output shape.
            perceptual_loss = self.perceptual_loss(y_true, y_pred)
            reconstruction_loss += self.perceptual_weight * perceptual_loss

        # The averaged reconstruction loss plays the role of the negative
        # log-likelihood term in the VQ-GAN style objective.
        neg_log_likelihood = tf.reduce_mean(reconstruction_loss)

        return neg_log_likelihood

        # # GAN part
        # if optimizer_idx == 0:
        #     if cond is None:
        #         assert not self.disc_conditional
        #         logits_fake = self.discriminator(y_pred)
        #     else:
        #         assert self.disc_conditional
        #         logits_fake = self.discriminator(tf.concat([y_pred, cond], axis=-1))
        #     g_loss = -tf.reduce_mean(logits_fake)