import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Dense,
    LeakyReLU,
    LSTMCell,
)
from tensorflow.keras.losses import Loss


class KLCriterion(Loss):
    """KL( N(mu_1, sigma2_1) || N(mu_2, sigma2_2) ) between two diagonal Gaussians."""

    def call(self, y_true, y_pred):
        (mu1, logvar1), (mu2, logvar2) = y_true, y_pred

        sigma1 = tf.exp(logvar1 * 0.5)
        sigma2 = tf.exp(logvar2 * 0.5)

        kld = (
            tf.math.log(sigma2 / sigma1)
            + (tf.exp(logvar1) + tf.square(mu1 - mu2)) / (2 * tf.exp(logvar2))
            - 0.5
        )
        # Sum over all dimensions, normalised by the (hard-coded) batch size.
        return tf.reduce_sum(kld) / 22


class Encoder(Model):
    """DCGAN-style encoder mapping a frame to a `dim`-dimensional latent vector,
    also returning the intermediate feature maps used as skip connections."""

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.c1 = Sequential(
            [
                Conv2D(64, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c2 = Sequential(
            [
                Conv2D(128, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c3 = Sequential(
            [
                Conv2D(256, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c4 = Sequential(
            [
                Conv2D(512, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.c5 = Sequential(
            [
                Conv2D(self.dim, kernel_size=4, strides=1, padding="valid"),
                BatchNormalization(),
                Activation("tanh"),
            ]
        )

    def call(self, inputs):
        h1 = self.c1(inputs)
        h2 = self.c2(h1)
        h3 = self.c3(h2)
        h4 = self.c4(h3)
        h5 = self.c5(h4)
        return tf.reshape(h5, (-1, self.dim)), [h1, h2, h3, h4, h5]


class Decoder(Model):
    """DCGAN-style decoder mapping a latent vector plus encoder skip connections
    back to a frame with `nc` channels."""

    def __init__(self, dim, nc=1):
        super().__init__()
        self.dim = dim
        self.upc1 = Sequential(
            [
                Conv2DTranspose(512, kernel_size=4, strides=1, padding="valid"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc2 = Sequential(
            [
                Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc3 = Sequential(
            [
                Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc4 = Sequential(
            [
                Conv2DTranspose(64, kernel_size=4, strides=2, padding="same"),
                BatchNormalization(),
                LeakyReLU(alpha=0.2),
            ]
        )
        self.upc5 = Sequential(
            [
                # Reconstruct `nc` output channels in [0, 1].
                Conv2DTranspose(nc, kernel_size=4, strides=2, padding="same"),
                Activation("sigmoid"),
            ]
        )

    def call(self, inputs):
        vec, skip = inputs
        d1 = self.upc1(tf.reshape(vec, (-1, 1, 1, self.dim)))
        # Concatenate the matching encoder feature map at each resolution.
        d2 = self.upc2(tf.concat([d1, skip[3]], axis=-1))
        d3 = self.upc3(tf.concat([d2, skip[2]], axis=-1))
        d4 = self.upc4(tf.concat([d3, skip[1]], axis=-1))
        output = self.upc5(tf.concat([d4, skip[0]], axis=-1))
        return output


class MyLSTM(Model):
    """Multi-layer LSTM that processes one time step per call, holding its state
    externally in `self.hidden`."""

    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(hidden_size, input_dim=input_shape)
        # One cell per layer; a single shared cell would tie the weights of all layers.
        self.lstm = [LSTMCell(hidden_size) for _ in range(n_layers)]
        self.out = Dense(output_size)

    def init_hidden(self, batch_size):
        hidden = []
        for _ in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        # Store via __dict__ to keep the state list out of Keras attribute tracking.
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        for i in range(self.n_layers):
            _, self.hidden[i] = self.lstm[i](h_in, self.hidden[i])
            h_in = self.hidden[i][0]
        return self.out(h_in)


class MyGaussianLSTM(Model):
    """Multi-layer LSTM that outputs a Gaussian (mu, logvar) and a sampled latent z."""

    def __init__(self, input_shape, hidden_size, output_size, n_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embed = Dense(hidden_size, input_dim=input_shape)
        # One cell per layer; a single shared cell would tie the weights of all layers.
        self.lstm = [LSTMCell(hidden_size) for _ in range(n_layers)]
        self.mu_net = Dense(output_size)
        self.logvar_net = Dense(output_size)

    def reparameterize(self, mu, logvar: tf.Tensor):
        # Sample z = mu + sigma * eps with eps ~ N(0, I).
        std = tf.exp(logvar * 0.5)
        eps = tf.random.normal(tf.shape(std))
        return eps * std + mu

    def init_hidden(self, batch_size):
        hidden = []
        for _ in range(self.n_layers):
            hidden.append(
                (
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                    tf.Variable(tf.zeros([batch_size, self.hidden_size])),
                )
            )
        # Store via __dict__ to keep the state list out of Keras attribute tracking.
        self.__dict__["hidden"] = hidden

    def build(self, input_shape):
        self.init_hidden(input_shape[0])

    def call(self, inputs):
        h_in = self.embed(inputs)
        for i in range(self.n_layers):
            _, self.hidden[i] = self.lstm[i](h_in, self.hidden[i])
            h_in = self.hidden[i][0]
        mu = self.mu_net(h_in)
        logvar = self.logvar_net(h_in)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar


class P2P(Model):
    """Point-to-point video generation model: an encoder/decoder pair with an LSTM
    frame predictor and Gaussian prior/posterior networks."""

    def __init__(
        self,
        channels: int = 1,
        g_dim: int = 128,
        z_dim: int = 10,
        rnn_size: int = 256,
        prior_rnn_layers: int = 1,
        posterior_rnn_layers: int = 1,
        predictor_rnn_layers: int = 1,
        skip_prob: float = 0.5,
        n_past: int = 1,
        last_frame_skip: bool = False,
        beta: float = 0.0001,
        weight_align: float = 0.1,
        weight_cpc: float = 100,
    ):
        super().__init__()
        self.channels = channels
        self.g_dim = g_dim
        self.z_dim = z_dim
        self.rnn_size = rnn_size
        self.prior_rnn_layers = prior_rnn_layers
        self.posterior_rnn_layers = posterior_rnn_layers
        self.predictor_rnn_layers = predictor_rnn_layers

        self.skip_prob = skip_prob
        self.n_past = n_past
        self.last_frame_skip = last_frame_skip
        self.beta = beta
        self.weight_align = weight_align
        self.weight_cpc = weight_cpc

        # Frame predictor consumes [frame latent, z, time-until-cp, delta-time].
        self.frame_predictor = MyLSTM(
            self.g_dim + self.z_dim + 1 + 1,
            self.rnn_size,
            self.g_dim,
            self.predictor_rnn_layers,
        )

        # Prior and posterior consume [frame latent, global descriptor, time-until-cp,
        # delta-time].
        self.prior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.prior_rnn_layers,
        )
        self.posterior = MyGaussianLSTM(
            self.g_dim + self.g_dim + 1 + 1,
            self.rnn_size,
            self.z_dim,
            self.posterior_rnn_layers,
        )

        self.encoder = Encoder(self.g_dim, self.channels)
        self.decoder = Decoder(self.g_dim, self.channels)

        self.mse_criterion = tf.keras.losses.MeanSquaredError()
        self.kl_criterion = KLCriterion()
        self.align_criterion = tf.keras.losses.MeanSquaredError()

        # One optimizer per sub-model, mirroring the separate update steps below.
        self.frame_predictor_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
        self.posterior_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
        self.prior_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
        self.encoder_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
        self.decoder_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

    def get_global_descriptor(self, x, start_ix=0, cp_ix=None):
        """Encode the control-point frame x[:, cp_ix] into a global descriptor."""
        if cp_ix is None:
            cp_ix = x.shape[1] - 1

        x_cp = x[:, cp_ix, ...]
        h_cp = self.encoder(x_cp)[0]

        return x_cp, h_cp

    def call(self, x, start_ix=0, cp_ix=-1):
        batch_size = x.shape[0]

        with tf.GradientTape(persistent=True) as tape:
            mse_loss = 0
            kld_loss = 0
            cpc_loss = 0
            align_loss = 0

            seq_len = x.shape[1]
            start_ix = 0
            cp_ix = seq_len - 1
            x_cp, global_z = self.get_global_descriptor(x, start_ix, cp_ix)

            skip_prob = self.skip_prob

            prev_i = 0
            max_skip_count = seq_len * skip_prob
            skip_count = 0
            probs = np.random.uniform(low=0, high=1, size=seq_len - 1)

            for i in range(1, seq_len):
                # Skip-frame training: randomly drop frames, but never the first
                # predicted frame or the control-point frame.
                if (
                    probs[i - 1] <= skip_prob
                    and i >= self.n_past
                    and skip_count < max_skip_count
                    and i != 1
                    and i != cp_ix
                ):
                    skip_count += 1
                    continue

                # Normalised time until the control point and time since the last
                # kept frame.
                time_until_cp = tf.fill([batch_size, 1], (cp_ix - i + 1) / cp_ix)
                delta_time = tf.fill([batch_size, 1], (i - prev_i) / cp_ix)
                prev_i = i

                h = self.encoder(x[:, i - 1, ...])
                h_target = self.encoder(x[:, i, ...])[0]

                if self.last_frame_skip or i <= self.n_past:
                    h, skip = h
                else:
                    h = h[0]

                h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1)
                h_target_cpaw = tf.concat(
                    [h_target, global_z, time_until_cp, delta_time], axis=1
                )
                zt, mu, logvar = self.posterior(h_target_cpaw)
                zt_p, mu_p, logvar_p = self.prior(h_cpaw)

                concat = tf.concat([h, zt, time_until_cp, delta_time], axis=1)
                h_pred = self.frame_predictor(concat)
                x_pred = self.decoder([h_pred, skip])

                if i == cp_ix:
                    # Control-point consistency: the prior path must reconstruct
                    # the control-point frame.
                    h_pred_p = self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
                    )
                    x_pred_p = self.decoder([h_pred_p, skip])
                    cpc_loss = self.mse_criterion(x_pred_p, x_cp)

                if i > 1:
                    # Alignment regulariser: keep the predicted latent close to
                    # the encoder latent.
                    align_loss += self.align_criterion(h, h_pred)

                mse_loss += self.mse_criterion(x_pred, x[:, i, ...])
                kld_loss += self.kl_criterion((mu, logvar), (mu_p, logvar_p))
            loss = mse_loss + kld_loss * self.beta + align_loss * self.weight_align
            prior_loss = kld_loss + cpc_loss * self.weight_cpc

        var_list_frame_predictor = self.frame_predictor.trainable_variables
        var_list_posterior = self.posterior.trainable_variables
        var_list_prior = self.prior.trainable_variables
        var_list_encoder = self.encoder.trainable_variables
        var_list_decoder = self.decoder.trainable_variables

        # Per-module variable lists; the order must match both the flat list below
        # and the optimizer order in `update_model_without_prior`.
        var_lists_without_prior = [
            var_list_frame_predictor,
            var_list_posterior,
            var_list_encoder,
            var_list_decoder,
        ]
        var_list_without_prior = (
            var_list_frame_predictor
            + var_list_posterior
            + var_list_encoder
            + var_list_decoder
        )

        gradients_without_prior = tape.gradient(loss, var_list_without_prior)
        gradients_prior = tape.gradient(prior_loss, var_list_prior)

        self.update_model_without_prior(
            gradients_without_prior, var_lists_without_prior
        )
        self.update_prior(gradients_prior, var_list_prior)
        del tape

        return (
            mse_loss / seq_len,
            kld_loss / seq_len,
            cpc_loss / seq_len,
            align_loss / seq_len,
        )

    def p2p_generate(
        self,
        x,
        len_output,
        eval_cp_ix,
        start_ix=0,
        cp_ix=-1,
        model_mode="full",
        skip_frame=False,
        init_hidden=True,
    ):
        batch_size, num_frames = x.shape[0], x.shape[1]

        gen_seq = [x[:, 0, ...]]
        x_in = x[:, 0, ...]

        seq_len = x.shape[1]
        cp_ix = seq_len - 1

        x_cp, global_z = self.get_global_descriptor(x, cp_ix=cp_ix)

        skip_prob = self.skip_prob

        prev_i = 0
        max_skip_count = seq_len * skip_prob
        skip_count = 0
        probs = np.random.uniform(0, 1, len_output - 1)

        for i in range(1, len_output):
            # Optionally skip frames at generation time, emitting a blank placeholder.
            if (
                probs[i - 1] <= skip_prob
                and i >= self.n_past
                and skip_count < max_skip_count
                and i != 1
                and i != (len_output - 1)
                and skip_frame
            ):
                skip_count += 1
                gen_seq.append(tf.zeros_like(x_in))
                continue

            time_until_cp = tf.fill([batch_size, 1], (eval_cp_ix - i + 1) / eval_cp_ix)
            delta_time = tf.fill([batch_size, 1], (i - prev_i) / eval_cp_ix)
            prev_i = i

            h = self.encoder(x_in)

            if self.last_frame_skip or i == 1 or i < self.n_past:
                h, skip = h
            else:
                h, _ = h

            h_cpaw = tf.concat([h, global_z, time_until_cp, delta_time], axis=1)

            if i < self.n_past:
                # Conditioning frames: feed ground truth and only update the RNN states.
                h_target = self.encoder(x[:, i, ...])[0]
                h_target_cpaw = tf.concat(
                    [h_target, global_z, time_until_cp, delta_time], axis=1
                )

                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior" or model_mode == "full":
                    self.frame_predictor(
                        tf.concat([h, zt, time_until_cp, delta_time], axis=1)
                    )
                elif model_mode == "prior":
                    self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
                    )

                x_in = x[:, i, ...]
                gen_seq.append(x_in)
            else:
                if i < num_frames:
                    h_target = self.encoder(x[:, i, ...])[0]
                    h_target_cpaw = tf.concat(
                        [h_target, global_z, time_until_cp, delta_time], axis=1
                    )
                else:
                    h_target_cpaw = h_cpaw

                zt, _, _ = self.posterior(h_target_cpaw)
                zt_p, _, _ = self.prior(h_cpaw)

                if model_mode == "posterior":
                    h = self.frame_predictor(
                        tf.concat([h, zt, time_until_cp, delta_time], axis=1)
                    )
                elif model_mode == "prior" or model_mode == "full":
                    h = self.frame_predictor(
                        tf.concat([h, zt_p, time_until_cp, delta_time], axis=1)
                    )

                x_in = self.decoder([h, skip])
                gen_seq.append(x_in)
        return tf.stack(gen_seq, axis=1)

    def update_model_without_prior(self, gradients, var_lists):
        # `gradients` is a flat list aligned with the concatenation of `var_lists`;
        # each sub-model's slice is applied by its own optimizer.
        optimizers = [
            self.frame_predictor_optimizer,
            self.posterior_optimizer,
            self.encoder_optimizer,
            self.decoder_optimizer,
        ]
        start = 0
        for optimizer, var_list in zip(optimizers, var_lists):
            end = start + len(var_list)
            optimizer.apply_gradients(zip(gradients[start:end], var_list))
            start = end

    def update_prior(self, gradients, var_list):
        self.prior_optimizer.apply_gradients(zip(gradients, var_list))
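

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the model. The shapes and settings
# below are assumptions for illustration only: a batch of 2 sequences of
# 8 frames of 64x64 grayscale images, run eagerly under TensorFlow 2.x.
# Calling the model performs one training step on the random batch.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = P2P()

    # Random stand-in data with layout (batch, time, height, width, channels).
    x = tf.random.uniform((2, 8, 64, 64, 1))

    mse, kld, cpc, align = model(x)
    print("mse:", float(mse), "kld:", float(kld), "cpc:", float(cpc), "align:", float(align))

    # Generate 12 frames conditioned on the first frame, with the control point
    # placed at the final output index.
    generated = model.p2p_generate(x, len_output=12, eval_cp_ix=11)
    print("generated sequence shape:", generated.shape)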