Kurokabe's picture
Upload 84 files
3be620b
raw
history blame
5.09 kB
import os
import numpy as np
import tensorflow as tf
import torchvision.models as models
from tensorflow import keras
from tensorflow.keras import Model, Sequential
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.losses import Loss
from pyprojroot.pyprojroot import here
def normalize_tensor(x, eps=1e-10, axis=-1):
    """Scale ``x`` to unit L2 norm along ``axis``.

    Args:
        x: tensor to normalize.
        eps: added to the norm to avoid division by zero.
        axis: axis along which the L2 norm is computed. The default ``-1``
            is the channel axis of channels-last feature maps.

    Returns:
        ``x / (||x||_2 + eps)``, with the norm taken along ``axis`` and kept
        as a size-1 dimension so it broadcasts back over ``x``.
    """
    norm_factor = tf.sqrt(tf.reduce_sum(x**2, axis=axis, keepdims=True))
    return x / (norm_factor + eps)
class LPIPS(Loss):
    """LPIPS-style perceptual distance computed on frozen VGG16 features.

    Both images are normalized by ``ScalingLayer`` and passed through a frozen
    headless VGG16 whose weights are copied from torchvision. They are then
    compared at five intermediate conv layers. At each layer:

    1. the activations are unit-normalized along the last (channel) axis;
    2. their squared difference is weighted by a 1x1 ``NetLinLayer``;
    3. the result is averaged over the two spatial axes.

    The per-layer scores are then summed.

    NOTE(review): inputs are assumed to be channels-last 3-channel images
    (the ScalingLayer constants broadcast over the last axis). The expected
    value range is not visible here; the constants look like those of the
    reference LPIPS code, which expects [-1, 1] -- confirm with callers.

    Args:
        use_dropout: forwarded to every ``NetLinLayer``.
        **kwargs: forwarded to ``keras.losses.Loss`` (e.g. ``name``,
            ``reduction``).
    """

    def __init__(self, use_dropout=True, **kwargs):
        super().__init__(**kwargs)
        self.use_dropout = use_dropout
        self.scaling_layer = ScalingLayer()  # plays the role of preprocess_input
        selected_layers = [
            "block1_conv2",
            "block2_conv2",
            "block3_conv3",
            "block4_conv3",
            "block5_conv3",
        ]
        # TODO: the VGG16 weights are converted from PyTorch; try the native
        # Keras weights: VGG16(weights="imagenet", include_top=False).
        self.net = self.load_vgg16()
        self.net.trainable = False
        outputs = [self.net.get_layer(name).output for name in selected_layers]
        # One multi-output model returning every compared activation at once.
        self.model = Model(self.net.input, outputs)
        # One 1x1 projection per compared layer, sized to that layer's output.
        self.lins = [
            NetLinLayer(input_shape=output.shape[1:], use_dropout=use_dropout)
            for output in outputs
        ]
        # TODO: these are the PyTorch weights of the linear layers; try without
        # these layers, or without initializing the weights.
        self.init_lin_layers()

    def load_vgg16(self) -> Model:
        """Build a headless Keras VGG16 carrying torchvision's pretrained weights.

        Adapted from https://github.com/ezavarygin/vgg16_pytorch2keras

        Returns:
            A ``VGG16(include_top=False, weights=None)`` Keras model whose conv
            kernels and biases are copied from ``torchvision.models.vgg16``.

        Raises:
            NotImplementedError: if a copied parameter is neither a 4-D conv
                kernel nor a 1-D bias.
        """
        pytorch_model = models.vgg16(pretrained=True)
        # The first 26 parameters are the kernel/bias pairs of the conv layers;
        # the fully connected classifier parameters come after them.
        keras_weights = []
        for param in list(pytorch_model.parameters())[:26]:
            if param.dim() == 4:
                # PyTorch kernels are (out, in, kh, kw); Keras wants (kh, kw, in, out).
                keras_weights.append(
                    np.transpose(param.detach().numpy(), axes=[2, 3, 1, 0])
                )
            elif param.dim() == 1:
                keras_weights.append(param.detach().numpy())
            else:
                raise NotImplementedError("Fully connected layers are not implemented.")
        keras_model = VGG16(include_top=False, weights=None)
        keras_model.set_weights(keras_weights)
        return keras_model

    def init_lin_layers(self):
        """Load pretrained weights into every ``NetLinLayer`` in ``self.lins``.

        The weights for layer ``i`` are read from
        ``<project root>/models/NetLinLayer/numpy_{i}.npy``, where the project
        root is resolved by ``pyprojroot.here``.
        """
        weights_dir = os.path.join(here(), "models", "NetLinLayer")
        # Iterate the actual heads instead of a hard-coded count so the loader
        # stays in sync with the list of compared layers.
        for i, lin in enumerate(self.lins):
            weights = np.load(os.path.join(weights_dir, f"numpy_{i}.npy"))
            # Moves axis 1 to position 2 -- presumably turning a PyTorch
            # (out, in, 1, 1) kernel into Keras' (1, 1, in, out) layout; TODO confirm.
            weights = np.moveaxis(weights, 1, 2)
            lin.set_weights([weights])

    def call(self, y_true, y_pred):
        """Return the perceptual distance between ``y_true`` and ``y_pred``.

        Returns:
            A scalar tensor: the per-layer scores summed over the layers and
            also over the batch (``tf.reduce_sum`` is applied with no axis).

        NOTE(review): Keras documents ``Loss.call`` as returning per-sample
        losses. Summing over the batch here makes the value grow with batch
        size. This is kept unchanged because callers may be tuned to this scale.
        """
        scaled_true = self.scaling_layer(y_true)
        scaled_pred = self.scaling_layer(y_pred)
        outputs_true = self.model(scaled_true)
        outputs_pred = self.model(scaled_pred)
        layer_scores = []
        for feat_true, feat_pred, lin in zip(outputs_true, outputs_pred, self.lins):
            # Compare unit-normalized activations (normalized along the last axis).
            sq_diff = (normalize_tensor(feat_true) - normalize_tensor(feat_pred)) ** 2
            # Weight channels with the 1x1 conv, then average over the two
            # spatial axes while keeping them as size-1 dims.
            layer_scores.append(
                tf.reduce_mean(lin(sq_diff), axis=(-3, -2), keepdims=True)
            )
        return tf.reduce_sum(layer_scores)
class ScalingLayer(layers.Layer):
    """Apply a fixed per-channel affine normalization: ``(inputs - shift) / scale``.

    ``shift`` and ``scale`` hold three values and broadcast along the last
    axis. Inputs are therefore assumed to be channels-last with 3 channels.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Fixed constants, not parameters to learn. They are non-trainable so an
        # optimizer training a model that contains this layer cannot update them.
        self.shift = tf.Variable([-0.030, -0.088, -0.188], trainable=False)
        self.scale = tf.Variable([0.458, 0.448, 0.450], trainable=False)

    def call(self, inputs):
        # The constants are float32. Promote any other floating input (float16
        # or bfloat16 under a mixed-precision policy, float64, ...) so that the
        # arithmetic below does not fail with a dtype mismatch.
        if inputs.dtype.is_floating and inputs.dtype != tf.float32:
            inputs = tf.cast(inputs, tf.float32)
        return (inputs - self.shift) / self.scale
class NetLinLayer(layers.Layer):
    """Optional dropout followed by a bias-free 1x1 convolution.

    Args:
        input_shape: shape of one input sample, without the batch dimension.
        channels_out: number of output channels of the 1x1 convolution.
        use_dropout: if True, apply ``Dropout(0.5)`` before the convolution.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, input_shape, channels_out=1, use_dropout=False, **kwargs):
        super().__init__(**kwargs)
        inputs = tf.keras.Input(shape=input_shape)
        x = inputs
        if use_dropout:
            x = layers.Dropout(0.5)(x)
        x = layers.Conv2D(channels_out, 1, padding="same", use_bias=False)(x)
        # Identity activation used only to force a float32 output, e.g. when a
        # mixed-precision policy makes the convolution compute in float16.
        x = layers.Activation("linear", dtype="float32")(x)
        self.model = Model(inputs=inputs, outputs=x)

    def call(self, inputs):
        return self.model(inputs)