File size: 4,688 Bytes
3be620b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import os
import numpy as np
import tensorflow as tf
import torchvision.models as models
from tensorflow import keras
from tensorflow.keras import Model, Sequential
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.losses import Loss
from pyprojroot.pyprojroot import here
def normalize_tensor(x, eps=1e-10):
    """Scale ``x`` to unit L2 norm along its last axis.

    For channels-last feature maps the last axis is the channel axis, so each
    spatial position's feature vector is normalised independently.

    Args:
        x: feature tensor, normalised along ``axis=-1``.
        eps: added to the norm to avoid division by zero.

    Returns:
        ``x / (||x||_2 + eps)``, with the same shape as ``x``.
    """
    sum_sq = tf.reduce_sum(x**2, axis=-1, keepdims=True)
    # Clamp before the sqrt: d/ds sqrt(s) is infinite at s == 0, which turns the
    # gradient of an all-zero feature vector into NaN. The clamp only changes the
    # result when the sum of squares is below eps**2.
    norm_factor = tf.sqrt(tf.maximum(sum_sq, eps**2))
    return x / (norm_factor + eps)
class LPIPS(Loss):
    """LPIPS-style perceptual loss computed on VGG16 features.

    Both images are rescaled by ``ScalingLayer``, passed through a frozen VGG16
    whose convolution weights are copied from torchvision, and compared at five
    intermediate layers. At each layer the features are unit-normalised along
    the last axis, their squared difference is weighted by a 1x1
    ``NetLinLayer`` convolution and averaged over the two spatial axes, and the
    per-layer results are summed.

    Args:
        use_dropout: whether each ``NetLinLayer`` applies dropout before its
            1x1 convolution.
        **kwargs: forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``).
    """

    def __init__(self, use_dropout=True, **kwargs):
        super().__init__(**kwargs)
        self.scaling_layer = ScalingLayer()  # alternative: keras preprocess_input
        selected_layers = [
            "block1_conv2",
            "block2_conv2",
            "block3_conv3",
            "block4_conv3",
            "block5_conv3",
        ]
        # TODO: these are the PyTorch (torchvision) weights; also try the Keras
        # ones, i.e. VGG16(weights="imagenet", include_top=False).
        self.net = self.load_vgg16()
        self.net.trainable = False
        outputs = [self.net.get_layer(name).output for name in selected_layers]
        self.model = Model(self.net.input, outputs)
        self.lins = [NetLinLayer(use_dropout=use_dropout) for _ in selected_layers]
        # One dummy forward pass builds each NetLinLayer's Conv2D kernel, which
        # must exist before init_lin_layers() can call set_weights(). The
        # 1-channel input is broadcast to 3 channels by ScalingLayer.
        # TODO: here we use the pytorch weights of the linear layers, try
        # without these layers, or without initializing the weights
        self(tf.zeros((1, 16, 16, 1)), tf.zeros((1, 16, 16, 1)))
        self.init_lin_layers()

    def load_vgg16(self) -> Model:
        """Build a Keras VGG16 (without top) carrying torchvision's weights.

        Adapted from https://github.com/ezavarygin/vgg16_pytorch2keras

        Returns:
            A ``VGG16(include_top=False)`` model with the converted weights set.

        Raises:
            NotImplementedError: if one of the selected parameters is neither a
                4-D convolution kernel nor a 1-D bias.
        """
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept for compatibility with the installed version.
        pytorch_model = models.vgg16(pretrained=True)
        # The first 26 parameters are the kernel/bias pairs of VGG16's
        # convolutional layers; the classifier's parameters follow them.
        conv_params = list(pytorch_model.parameters())[:26]
        keras_weights = []
        for param in conv_params:
            if param.dim() == 4:
                # PyTorch kernels are (out, in, kh, kw); Keras wants (kh, kw, in, out).
                keras_weights.append(
                    np.transpose(param.detach().numpy(), axes=[2, 3, 1, 0])
                )
            elif param.dim() == 1:
                keras_weights.append(param.detach().numpy())
            else:
                raise NotImplementedError("Fully connected layers are not implemented.")
        keras_model = VGG16(include_top=False, weights=None)
        keras_model.set_weights(keras_weights)
        return keras_model

    def init_lin_layers(self):
        """Load the pretrained 1x1-conv kernel of every NetLinLayer from disk.

        Reads ``models/NetLinLayer/numpy_{i}.npy`` under the project root, one
        file per entry of ``self.lins``.
        """
        for i, lin in enumerate(self.lins):
            weights = np.load(
                os.path.join(here(), "models", "NetLinLayer", f"numpy_{i}.npy")
            )
            # NOTE(review): presumably the files hold PyTorch-layout
            # (out, in, 1, 1) kernels; moving axis 1 to 2 then gives the Keras
            # (1, 1, in, out) layout — confirm against the export script.
            weights = np.moveaxis(weights, 1, 2)
            # The Conv2D is the last layer of the Sequential whether or not a
            # leading Dropout exists; index 1 failed when use_dropout=False.
            lin.model.layers[-1].set_weights([weights])

    def call(self, y_true, y_pred):
        """Return the summed LPIPS distance between ``y_true`` and ``y_pred``.

        NOTE(review): ``tf.reduce_sum`` over the stacked per-layer results also
        sums over the batch, so a scalar is returned instead of one value per
        sample, and Keras' ``reduction`` cannot average over the batch. This is
        kept to preserve current loss magnitudes; use ``tf.add_n(per_layer)``
        if a per-sample loss is intended.
        """
        feats_true = self.model(self.scaling_layer(y_true))
        feats_pred = self.model(self.scaling_layer(y_pred))
        per_layer = []
        for lin, f_true, f_pred in zip(self.lins, feats_true, feats_pred):
            diff = (normalize_tensor(f_true) - normalize_tensor(f_pred)) ** 2
            # Average the weighted difference over the height and width axes.
            per_layer.append(tf.reduce_mean(lin(diff), axis=(-3, -2), keepdims=True))
        return tf.reduce_sum(per_layer)
class ScalingLayer(layers.Layer):
    """Shift and scale each of three input channels before the VGG16 network.

    Computes ``(inputs - shift) / scale`` with 3-element constants that
    broadcast over the last axis (a 1-channel input is therefore expanded to
    3 channels).

    NOTE(review): the constants look like the reference LPIPS values for
    inputs in [-1, 1] — confirm the expected input range with callers.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Fixed normalisation constants: non-trainable so they never show up in
        # trainable_weights or get updated by an optimizer.
        self.shift = tf.Variable([-0.030, -0.088, -0.188], trainable=False)
        self.scale = tf.Variable([0.458, 0.448, 0.450], trainable=False)

    def call(self, inputs):
        return (inputs - self.shift) / self.scale
class NetLinLayer(layers.Layer):
    """Learned channel weighting: optional ``Dropout(0.5)`` then a 1x1 Conv2D.

    The wrapped ``Sequential`` is exposed as ``self.model``; its last layer is
    always the bias-free 1x1 convolution.

    Args:
        channels_out: number of output channels of the 1x1 convolution.
        use_dropout: if True, apply ``Dropout(0.5)`` before the convolution.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, channels_out=1, use_dropout=False, **kwargs):
        super().__init__(**kwargs)
        head = [layers.Dropout(0.5)] if use_dropout else []
        # Kernel size 1 without bias: a pure learned linear combination of the
        # input channels at every spatial position.
        conv = layers.Conv2D(channels_out, 1, padding="same", use_bias=False)
        self.model = Sequential(head + [conv])

    def call(self, inputs):
        return self.model(inputs)
|