from typing import List

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.losses import Loss

from .lpips import LPIPS


class PerceptualLoss(Loss):
    """L1 reconstruction loss plus an optional weighted LPIPS perceptual term.

    The returned value is the scalar

        mean(|y_true - y_pred|) + perceptual_weight * mean(LPIPS(y_true, y_pred))

    where the LPIPS term is skipped entirely when ``perceptual_weight <= 0``.
    """

    def __init__(self, *, perceptual_weight: float = 1.0, **kwargs):
        """Perceptual loss based on the LPIPS metric.

        Args:
            perceptual_weight (float, optional): The weight of the perceptual
                loss. A value ``<= 0`` disables the LPIPS term. Defaults to 1.0.
            **kwargs: Forwarded unchanged to ``tf.keras.losses.Loss``
                (e.g. ``name``, ``reduction``).
        """
        super().__init__(**kwargs)
        # Unreduced LPIPS so this class controls how the term is averaged.
        self.perceptual_loss = LPIPS(reduction=tf.keras.losses.Reduction.NONE)
        self.perceptual_weight = perceptual_weight

    def get_config(self):
        """Return the serialisable config, including ``perceptual_weight``."""
        config = super().get_config()
        config.update(
            {
                "perceptual_weight": self.perceptual_weight,
            }
        )
        return config

    def call(
        self,
        y_true,
        y_pred,
    ):
        """Compute the combined reconstruction + perceptual loss.

        Args:
            y_true: Target tensor. It is cast to ``y_pred``'s dtype. Presumably
                a batch of images in the value range LPIPS expects — TODO
                confirm against the LPIPS implementation.
            y_pred: Predicted tensor with the same shape as ``y_true``.

        Returns:
            A scalar tensor: the mean absolute error, plus
            ``perceptual_weight`` times the mean LPIPS value when the weight
            is positive.
        """
        # Match Keras' built-in losses: align dtypes before subtracting.
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)

        neg_log_likelihood = tf.reduce_mean(tf.abs(y_true - y_pred))

        if self.perceptual_weight > 0:
            perceptual_loss = self.perceptual_loss(y_true, y_pred)
            # Reduce each term on its own instead of adding the LPIPS output
            # to the per-element L1 map. Broadcasting repeats elements
            # uniformly, so mean(a + w * p) == mean(a) + w * mean(p) whenever
            # the shapes broadcast. This form also works for any LPIPS output
            # shape (e.g. (B,) vs. (B, H, W, C)), where the elementwise add
            # would fail or silently mis-align.
            perceptual_term = tf.cast(
                tf.reduce_mean(perceptual_loss), neg_log_likelihood.dtype
            )
            neg_log_likelihood += self.perceptual_weight * perceptual_term

        return neg_log_likelihood