Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
1.37 kB
from typing import List
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.losses import Loss
from .lpips import LPIPS
class PerceptualLoss(Loss):
def __init__(self, *, perceptual_weight: float = 1.0, **kwargs):
"""Perceptual loss based on the LPIPS metric.
Args:
perceptual_weight (float, optional): The weight of the perceptual loss. Defaults to 1.0.
"""
super().__init__(**kwargs)
self.perceptual_loss = LPIPS(reduction=tf.keras.losses.Reduction.NONE)
self.perceptual_weight = perceptual_weight
def get_config(self):
config = super().get_config()
config.update(
{
"perceptual_weight": self.perceptual_weight,
}
)
return config
def call(
self,
y_true,
y_pred,
):
reconstruction_loss = tf.abs(y_true - y_pred)
if self.perceptual_weight > 0:
perceptual_loss = self.perceptual_loss(y_true, y_pred)
reconstruction_loss += self.perceptual_weight * perceptual_loss
else:
perceptual_loss = 0.0
neg_log_likelihood = tf.reduce_mean(reconstruction_loss)
return neg_log_likelihood