# musika_api / models.py
# nakas's picture
# musika clone
# 050507e
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils.layer_utils import count_params
from layers import AddNoise
class Models_functions:
def __init__(self, args):
    """Keep the run configuration and set up settings shared by all builders.

    Args:
        args: Configuration object. Only ``args.mixed_precision`` is read
            here; the builder methods read further attributes from it.
    """
    self.args = args
    if args.mixed_precision:
        # Remember the module (used later to wrap optimizers) and make
        # "mixed_float16" the global Keras dtype policy.
        mp = tf.keras.mixed_precision
        self.mixed_precision = mp
        self.policy = mp.Policy("mixed_float16")
        mp.set_global_policy(self.policy)
    # Kernel initializer passed to the Conv2D / Conv2DTranspose / Dense
    # layers created by this class.
    self.init = tf.keras.initializers.he_uniform()
def conv_util(
    self, inp, filters, kernel_size=(1, 3), strides=(1, 1), noise=False, upsample=False, padding="same", bnorm=True
):
    """Apply one convolution unit: conv -> optional noise -> optional BN -> swish.

    Args:
        inp: Input tensor fed to the convolution.
        filters: Number of output channels.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        noise: When true, an ``AddNoise`` layer follows the convolution.
        upsample: When true a ``Conv2DTranspose`` is used instead of ``Conv2D``.
        padding: Padding mode passed to the convolution.
        bnorm: When true, batch normalization is applied and the
            convolution is created without a bias term.

    Returns:
        The swish-activated output tensor.
    """
    if upsample:
        conv_cls = tf.keras.layers.Conv2DTranspose
    else:
        conv_cls = tf.keras.layers.Conv2D

    out = conv_cls(
        filters,
        kernel_size=kernel_size,
        strides=strides,
        activation="linear",
        padding=padding,
        kernel_initializer=self.init,
        # The bias is dropped whenever batch normalization follows.
        use_bias=not bnorm,
    )(inp)

    if noise:
        out = AddNoise(self.args.datatype)(out)
    if bnorm:
        out = tf.keras.layers.BatchNormalization()(out)
    return tf.keras.activations.swish(out)
def pixel_shuffle(self, x, factor=2):
    """Move channel groups onto axis 2, widening it by ``factor``.

    A rank-4 input of shape ``(b, h, w, c)`` is returned with shape
    ``(b, h, w * factor, c // factor)``.

    Args:
        x: Rank-4 tensor; the batch size is read dynamically, the other
            three dimensions statically.
        factor: Widening factor for axis 2.

    Returns:
        The reshaped tensor.
    """
    batch = tf.shape(x)[0]
    height, width, channels = x.shape[1], x.shape[2], x.shape[3]
    out_channels = channels // factor

    grouped = tf.reshape(x, [batch, height, width, out_channels, factor])
    # Bring the `factor` axis next to the width axis before merging them.
    swapped = tf.transpose(grouped, [0, 1, 2, 4, 3])
    return tf.reshape(swapped, [batch, height, width * factor, out_channels])
def adain(self, x, emb, name):
    """Rescale ``x`` by a learned projection of ``emb``.

    ``x`` is divided by its standard deviation along axis -2 (plus 1e-5)
    and multiplied by a 1x1 convolution of ``emb`` that has as many
    channels as ``x``.

    Args:
        x: Tensor to modulate.
        emb: Conditioning tensor fed to the 1x1 convolution.
        name: Name given to the 1x1 convolution layer.

    Returns:
        The modulated tensor.
    """
    projection = tf.keras.layers.Conv2D(
        x.shape[-1],
        kernel_size=(1, 1),
        strides=1,
        activation="linear",
        padding="same",
        kernel_initializer=self.init,
        use_bias=True,
        name=name,
    )
    scale = projection(emb)
    normalized = x / (tf.math.reduce_std(x, -2, keepdims=True) + 1e-5)
    return normalized * scale
def conv_util_gen(
    self,
    inp,
    filters,
    kernel_size=(1, 9),
    strides=(1, 1),
    noise=False,
    upsample=False,
    emb=None,
    se1=None,
    name="0",
):
    """Apply one named generator unit: conv -> optional noise -> AdaIN or BN -> swish.

    Args:
        inp: Input tensor.
        filters: Number of output channels.
        kernel_size: Convolution kernel size.
        strides: Convolution strides.
        noise: When true an ``AddNoise`` layer (named ``name + "r"``) is added.
        upsample: When true a ``Conv2DTranspose`` is used instead of ``Conv2D``.
        emb: Conditioning tensor for ``adain``; when ``None`` batch
            normalization (named ``name + "bn"``) is used instead.
        se1: Unused by this method.
        name: Prefix for the names of the layers created here.

    Returns:
        The swish-activated output tensor.
    """
    conv_cls = tf.keras.layers.Conv2DTranspose if upsample else tf.keras.layers.Conv2D
    out = conv_cls(
        filters,
        kernel_size=kernel_size,
        strides=strides,
        activation="linear",
        padding="same",
        kernel_initializer=self.init,
        use_bias=True,
        name=name + "c",
    )(inp)

    if noise:
        out = AddNoise(self.args.datatype, name=name + "r")(out)

    if emb is None:
        out = tf.keras.layers.BatchNormalization(name=name + "bn")(out)
    else:
        out = self.adain(out, emb, name=name + "ai")

    return tf.keras.activations.swish(out)
def res_block_disc(self, inp, filters, kernel_size=(1, 3), kernel_size_2=None, strides=(1, 1), name="0"):
    """Build one residual block of the critic.

    The main path is two convolutions, each followed by LeakyReLU(0.2) and
    a multiplication by sqrt(0.5). The shortcut is average-pooled when
    ``strides`` is not ``(1, 1)`` and projected with a bias-free 1x1
    convolution when the channel count differs from ``filters``.

    Args:
        inp: Input tensor.
        filters: Output channel count of the block.
        kernel_size: Kernel of the second (strided) convolution.
        kernel_size_2: Kernel of the first convolution; defaults to
            ``kernel_size`` when ``None``.
        strides: Strides of the second convolution and of the shortcut pooling.
        name: Prefix for the convolution layer names.

    Returns:
        Sum of the main path and the shortcut.
    """
    first_kernel = kernel_size if kernel_size_2 is None else kernel_size_2

    def act_and_scale(t):
        # LeakyReLU followed by the sqrt(0.5) rescaling used on both convs.
        t = tf.keras.layers.LeakyReLU(0.2)(t)
        return tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * t

    main = tf.keras.layers.Conv2D(
        inp.shape[-1],
        kernel_size=first_kernel,
        strides=1,
        activation="linear",
        padding="same",
        kernel_initializer=self.init,
        name=name + "c0",
    )(inp)
    main = act_and_scale(main)

    main = tf.keras.layers.Conv2D(
        filters,
        kernel_size=kernel_size,
        strides=strides,
        activation="linear",
        padding="same",
        kernel_initializer=self.init,
        name=name + "c1",
    )(main)
    main = act_and_scale(main)

    shortcut = inp
    if strides != (1, 1):
        shortcut = tf.keras.layers.AveragePooling2D(strides, padding="same")(shortcut)
    if shortcut.shape[-1] != filters:
        shortcut = tf.keras.layers.Conv2D(
            filters,
            kernel_size=1,
            strides=1,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            use_bias=False,
            name=name + "c3",
        )(shortcut)

    return main + shortcut
def build_encoder2(self):
    """Build the ``ENC2`` Keras model.

    The input of shape ``(1, args.shape, args.hop // 4)`` is cut into 8
    pieces along axis -2 that are stacked on the batch axis, passed
    through a stack of ``conv_util`` layers and a tanh 1x1 bottleneck
    with ``args.latdepth`` channels, then re-joined along axis -2 and
    finally split in two along axis -2 and stacked on the batch axis.

    Returns:
        ``tf.keras.Model`` named ``"ENC2"`` with a float32 output.
    """
    hop = self.args.hop
    inpf = tf.keras.layers.Input((1, self.args.shape, hop // 4))
    pieces = tf.split(inpf, 8, -2)
    x = tf.concat(pieces, 0)

    # (filters, kernel_size, strides, padding) for each conv_util layer, in order.
    layer_specs = (
        (hop, (1, 3), (1, 1), "same"),
        (hop + hop // 2, (1, 3), (1, 2), "valid"),
        (hop + hop // 2, (1, 3), (1, 1), "same"),
        (hop * 2, (1, 3), (1, 2), "valid"),
        (hop * 2, (1, 3), (1, 1), "same"),
        (hop * 3, (1, 3), (1, 1), "valid"),
        (hop * 3, (1, 1), (1, 1), "valid"),
    )
    for n_filters, kernel, stride, pad in layer_specs:
        x = self.conv_util(x, n_filters, kernel_size=kernel, strides=stride, padding=pad, bnorm=False)

    latent = tf.keras.layers.Conv2D(
        self.args.latdepth,
        kernel_size=(1, 1),
        strides=1,
        padding="valid",
        kernel_initializer=self.init,
        name="cbottle",
        activation="tanh",
    )(x)

    # Undo the 8-way batch stacking, then stack two halves on the batch axis.
    latent = tf.concat(tf.split(latent, 8, 0), -2)
    latent = tf.concat(tf.split(latent, 2, -2), 0)
    return tf.keras.Model(inpf, tf.cast(latent, tf.float32), name="ENC2")
def build_decoder2(self):
    """Build the ``DEC2`` Keras model.

    The input of shape ``(1, args.shape // 32, args.latdepth)`` goes
    through seven noisy ``conv_util`` layers (four of them stride-2
    transposed convolutions) and a 1x1 output convolution with
    ``args.hop // 4`` channels. The batch is then split in two and the
    halves are joined along axis -2.

    Returns:
        ``tf.keras.Model`` named ``"DEC2"`` with a float32 output.
    """
    hop = self.args.hop
    inpf = tf.keras.layers.Input((1, self.args.shape // 32, self.args.latdepth))

    # (filters, upsample) per layer. Upsampling layers use kernel (1, 4)
    # with stride (1, 2); the others use kernel (1, 3) with stride (1, 1).
    layer_specs = (
        (hop * 3, False),
        (hop * 2 + hop // 2, True),
        (hop * 2 + hop // 2, False),
        (hop * 2, True),
        (hop * 2, False),
        (hop + hop // 2, True),
        (hop, True),
    )
    x = inpf
    for n_filters, up in layer_specs:
        x = self.conv_util(
            x,
            n_filters,
            kernel_size=(1, 4) if up else (1, 3),
            strides=(1, 2) if up else (1, 1),
            upsample=up,
            noise=True,
            bnorm=False,
        )

    out = tf.keras.layers.Conv2D(
        hop // 4, kernel_size=(1, 1), strides=1, padding="same", kernel_initializer=self.init, name="cout"
    )(x)
    halves = tf.split(out, 2, 0)
    out = tf.concat(halves, -2)
    return tf.keras.Model(inpf, tf.cast(out, tf.float32), name="DEC2")
def build_encoder(self):
    """Build the ``ENC`` Keras model.

    The input of shape ``(dim, args.shape, 1)`` with
    ``dim = (4 * args.hop) // 2 + 1`` is transposed so that ``dim``
    becomes the channel axis, passed through five 1x1 ``conv_util``
    layers with ``args.hop * 4`` channels and a 1x1 convolution with
    ``args.hop // 4`` channels followed by tanh. The result is split in
    two along axis -2 and the halves are stacked on the batch axis.

    Returns:
        ``tf.keras.Model`` named ``"ENC"`` with a float32 output.
    """
    hop = self.args.hop
    dim = ((4 * hop) // 2) + 1
    inpf = tf.keras.layers.Input((dim, self.args.shape, 1))

    x = tf.transpose(inpf, [0, 3, 2, 1])
    for _ in range(5):
        x = self.conv_util(x, hop * 4, kernel_size=(1, 1), strides=(1, 1), padding="valid", bnorm=False)

    x = tf.keras.layers.Conv2D(
        hop // 4, kernel_size=(1, 1), strides=1, padding="valid", kernel_initializer=self.init
    )(x)
    x = tf.keras.activations.tanh(x)

    halves = tf.split(x, 2, -2)
    x = tf.concat(halves, 0)
    return tf.keras.Model(inpf, tf.cast(x, tf.float32), name="ENC")
def build_decoder(self):
    """Build the ``DEC`` Keras model with two output heads.

    The input of shape ``(1, args.shape // 2, args.hop // 4)`` is
    processed by strided ``conv_util`` layers and then by transposed
    ``conv_util`` layers with additive skip connections. Two 1x1
    convolutions with ``dim = (4 * args.hop) // 2 + 1`` channels produce
    the two heads, each clipped to [-1, 1]. Each head is re-assembled
    from ``args.shape // args.window`` batch pieces along axis -2,
    transposed and squeezed on its last axis.

    Returns:
        ``tf.keras.Model`` named ``"DEC"`` whose outputs are ``[s, p]``,
        both float32.
    """
    hop = self.args.hop
    dim = ((4 * hop) // 2) + 1
    inpf = tf.keras.layers.Input((1, self.args.shape // 2, hop // 4))

    def down(t, n_filters, stride):
        return self.conv_util(t, n_filters, kernel_size=(1, 3), strides=stride, noise=True, bnorm=False)

    def up(t, n_filters):
        return self.conv_util(
            t, n_filters, kernel_size=(1, 4), strides=(1, 2), upsample=True, noise=True, bnorm=False
        )

    d0 = down(inpf, hop * 3, (1, 1))
    d1 = down(d0, hop * 3, (1, 2))
    d2 = down(d1, hop * 2, (1, 2))
    d3 = down(d2, hop, (1, 2))
    d4 = down(d3, hop, (1, 2))
    # NOTE: this branch never reaches a model output. It is kept because
    # creating it advances Keras' automatic layer-name counters.
    u3_unused = up(d4, hop)

    u2 = up(d3, hop * 2)
    u1 = up(u2 + d2, hop * 3)
    u0 = up(u1 + d1, hop * 3)

    # First head.
    head_s = tf.keras.layers.Conv2D(
        dim, kernel_size=(1, 1), strides=(1, 1), kernel_initializer=self.init, padding="same"
    )(u0 + d0)
    gf = tf.clip_by_value(head_s, -1.0, 1.0)

    # Second head: its own pair of upsampling layers, fed by the first
    # head's intermediate tensors.
    q = up(u2, hop * 3)
    q = up(q + u1, hop * 3)
    head_p = tf.keras.layers.Conv2D(
        dim, kernel_size=(1, 1), strides=(1, 1), kernel_initializer=self.init, padding="same"
    )(q + u0)
    pf = tf.clip_by_value(head_p, -1.0, 1.0)

    n_pieces = self.args.shape // self.args.window
    gf = tf.concat(tf.split(gf, n_pieces, 0), -2)
    pf = tf.concat(tf.split(pf, n_pieces, 0), -2)

    s = tf.transpose(gf, [0, 2, 3, 1])
    p = tf.transpose(pf, [0, 2, 3, 1])
    s = tf.cast(tf.squeeze(s, -1), tf.float32)
    p = tf.cast(tf.squeeze(p, -1), tf.float32)
    return tf.keras.Model(inpf, [s, p], name="DEC")
def build_critic(self):
    """Build the critic model ``C``.

    The input of shape ``(1, args.latlen, args.latdepth * 2)`` goes
    through a strided convolution, four strided residual blocks (five
    when ``args.small`` is false), one more convolution and a single-unit
    dense layer on the flattened features.

    Returns:
        ``tf.keras.Model`` named ``"C"`` with a float32 output.
    """
    base = self.args.base_channels
    sinp = tf.keras.layers.Input(shape=(1, self.args.latlen, self.args.latdepth * 2))

    feat = tf.keras.layers.Conv2D(
        base * 3,
        kernel_size=(1, 4),
        strides=(1, 2),
        activation="linear",
        padding="same",
        kernel_initializer=self.init,
        name="1c",
    )(sinp)
    feat = tf.keras.layers.LeakyReLU(0.2)(feat)

    # Residual blocks "2".."5" widen the features from 4x to 7x base channels.
    for block_id, width in enumerate((4, 5, 6, 7), start=2):
        feat = self.res_block_disc(feat, base * width, kernel_size=(1, 4), strides=(1, 2), name=str(block_id))

    if not self.args.small:
        feat = self.res_block_disc(
            feat, base * 7, kernel_size=(1, 4), strides=(1, 2), kernel_size_2=(1, 1), name="6"
        )

    feat = tf.keras.layers.Conv2D(
        base * 7,
        kernel_size=(1, 3),
        strides=(1, 1),
        activation="linear",
        padding="same",
        kernel_initializer=self.init,
        name="7c",
    )(feat)
    feat = tf.keras.layers.LeakyReLU(0.2)(feat)

    score = tf.keras.layers.Dense(1, activation="linear", use_bias=True, kernel_initializer=self.init, name="7d")(
        tf.keras.layers.Flatten()(feat)
    )
    return tf.keras.Model(sinp, tf.cast(score, tf.float32), name="C")
def build_generator(self):
    """Build the generator model ``GEN``.

    The input of shape ``(args.latlen, args.latdepth * 2)`` is split in
    two along axis -2 and the halves are stacked on the batch axis. A
    pyramid of average-pooled copies of that tensor conditions the
    network through ``adain``. A dense layer on the coarsest pyramid
    level seeds a ``(1, 4, base_channels * 7)`` tensor, which is refined
    by four residual stages, a final upsampling unit and an output
    convolution with ``args.latdepth * 2`` channels. The two batch halves
    are then joined again along axis -2.

    ``args.small`` selects a shallower variant: one pyramid level fewer
    and a first stage that uses 1x1 convolutions without upsampling.

    All weight-carrying layers keep the explicit names and creation order
    of the previous implementation.

    Returns:
        ``tf.keras.Model`` named ``"GEN"`` with a float32 output.
    """
    small = self.args.small
    base = self.args.base_channels
    dim = self.args.latdepth * 2

    inpf = tf.keras.layers.Input((self.args.latlen, self.args.latdepth * 2))
    inpfls = tf.split(inpf, 2, -2)
    inpb = tf.concat(inpfls, 0)

    def pool(t):
        # Halve axis 2 with average pooling.
        return tf.keras.layers.AveragePooling2D((1, 2), padding="valid")(t)

    # Conditioning pyramid, from finest (inp1) to coarsest.
    inp1 = pool(tf.expand_dims(inpb, -3))
    inp2 = pool(inp1)
    inp3 = pool(inp2)
    inp4 = pool(inp3)
    inp5 = pool(inp4)
    if small:
        seed_src, seed_emb = inp5, inp4
    else:
        # The full-size model adds one more pyramid level for the seed.
        seed_src, seed_emb = pool(inp5), inp5

    g = tf.keras.layers.Dense(
        4 * (base * 7),
        activation="linear",
        use_bias=True,
        kernel_initializer=self.init,
        name="00d",
    )(tf.keras.layers.Flatten()(seed_src))
    g = tf.keras.layers.Reshape((1, 4, base * 7))(g)
    g = AddNoise(self.args.datatype, name="00n")(g)
    g = self.adain(g, seed_emb, name="00ai")
    g = tf.keras.activations.swish(g)

    def residual_stage(x, channels, emb, names, kernel_size=(1, 4), upsample=True):
        # Two conv_util_gen units, each scaled by sqrt(0.5), plus a 1x1
        # convolution shortcut. When upsampling, the shortcut input is
        # widened with pixel_shuffle.
        first_name, second_name, skip_name = names
        y = self.conv_util_gen(
            x,
            channels,
            kernel_size=kernel_size,
            strides=(1, 2) if upsample else (1, 1),
            upsample=upsample,
            noise=True,
            emb=emb,
            name=first_name,
        )
        y = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * y
        y = self.conv_util_gen(
            y,
            channels,
            kernel_size=kernel_size,
            strides=(1, 1),
            upsample=False,
            noise=True,
            emb=emb,
            name=second_name,
        )
        y = tf.math.sqrt(tf.cast(0.5, self.args.datatype)) * y
        skip_conv = tf.keras.layers.Conv2D(
            y.shape[-1],
            kernel_size=(1, 1),
            strides=1,
            activation="linear",
            padding="same",
            kernel_initializer=self.init,
            use_bias=True,
            name=skip_name,
        )
        skip_in = self.pixel_shuffle(x) if upsample else x
        return y + skip_conv(skip_in)

    if small:
        g1 = residual_stage(
            g,
            base * 6,
            inp4,
            ("0_small", "1_small", "res1c_small"),
            kernel_size=(1, 1),
            upsample=False,
        )
    else:
        g1 = residual_stage(g, base * 6, inp4, ("0", "1", "res1c"))
    g2 = residual_stage(g1, base * 5, inp3, ("2", "3", "res2c"))
    g3 = residual_stage(g2, base * 4, inp2, ("4", "5", "res3c"))
    g4 = residual_stage(g3, base * 3, inp1, ("6", "7", "res4c"))

    # The last upsampling unit is conditioned on the un-pooled input.
    g5 = self.conv_util_gen(
        g4,
        base * 2,
        kernel_size=(1, 4),
        strides=(1, 2),
        upsample=True,
        noise=True,
        emb=tf.expand_dims(tf.cast(inpb, dtype=self.args.datatype), -3),
        name="8",
    )
    gf = tf.keras.layers.Conv2D(
        dim, kernel_size=(1, 4), strides=(1, 1), kernel_initializer=self.init, padding="same", name="9c"
    )(g5)

    # Undo the batch stacking done on the input.
    gfls = tf.split(gf, 2, 0)
    gf = tf.concat(gfls, -2)
    gf = tf.cast(gf, tf.float32)
    return tf.keras.Model(inpf, gf, name="GEN")
def load(self, path, load_dec=False):
    """Build every network and restore saved state to resume training or test.

    Args:
        path: Directory containing ``critic.h5``, ``gen.h5``,
            ``gen_ema.h5`` and ``switch.npy``, plus ``opt_disc.npy`` and
            ``opt_dec.npy`` when ``args.testing`` is false. Only read when
            ``load_dec`` is false.
        load_dec: When true, only the encoder/decoder weights found under
            ``args.dec_path`` are restored; the other networks, the
            optimizers and ``switch`` keep their freshly built state.

    Returns:
        Tuple ``(critic, gen, enc, dec, enc2, dec2, gen_ema,
        [opt_dec, opt_disc], switch)``.
    """
    gen = self.build_generator()
    critic = self.build_critic()
    enc = self.build_encoder()
    dec = self.build_decoder()
    enc2 = self.build_encoder2()
    dec2 = self.build_decoder2()
    gen_ema = self.build_generator()
    switch = tf.Variable(-1.0, dtype=tf.float32)

    if self.args.mixed_precision:
        wrap = self.mixed_precision.LossScaleOptimizer
        opt_disc = wrap(tf.keras.optimizers.Adam(0.0001, 0.5))
        opt_dec = wrap(tf.keras.optimizers.Adam(0.0001, 0.5))
    else:
        # NOTE(review): the second Adam argument is 0.9 here but 0.5 in
        # build() and in the mixed-precision branch above. Confirm this
        # difference is intended.
        opt_disc = tf.keras.optimizers.Adam(0.0001, 0.9)
        opt_dec = tf.keras.optimizers.Adam(0.0001, 0.9)

    if not load_dec:
        # Apply an all-zero gradient step to each optimizer, presumably so
        # that its variables exist before set_weights() is called below.
        for optimizer, variables in (
            (opt_disc, critic.trainable_weights),
            (opt_dec, gen.trainable_variables),
        ):
            zeros = [tf.zeros_like(v) for v in variables]
            optimizer.apply_gradients(zip(zeros, variables))

        # NOTE(security): allow_pickle=True unpickles arbitrary objects;
        # only load checkpoints from a trusted source.
        if not self.args.testing:
            opt_disc.set_weights(np.load(path + "/opt_disc.npy", allow_pickle=True))
            opt_dec.set_weights(np.load(path + "/opt_dec.npy", allow_pickle=True))
        critic.load_weights(path + "/critic.h5")
        gen.load_weights(path + "/gen.h5")
        switch = tf.Variable(float(np.load(path + "/switch.npy", allow_pickle=True)), dtype=tf.float32)
        gen_ema.load_weights(path + "/gen_ema.h5")

    # Encoder/decoder weights are restored in both modes.
    dec_dir = self.args.dec_path
    dec.load_weights(dec_dir + "/dec.h5")
    dec2.load_weights(dec_dir + "/dec2.h5")
    enc.load_weights(dec_dir + "/enc.h5")
    enc2.load_weights(dec_dir + "/enc2.h5")

    return (
        critic,
        gen,
        enc,
        dec,
        enc2,
        dec2,
        gen_ema,
        [opt_dec, opt_disc],
        switch,
    )
def build(self):
    """Build every network from scratch, with fresh optimizers.

    ``gen_ema`` is a clone of ``gen`` that starts from the same weights.
    The previous implementation also built a second generator with
    ``build_generator()`` and discarded it straight away; that wasted
    construction is gone, and the returned objects are built the same way.

    Returns:
        Tuple ``(critic, gen, enc, dec, enc2, dec2, gen_ema,
        [opt_dec, opt_disc], switch)`` where ``switch`` is a float32
        variable initialised to -1.0.
    """
    gen = self.build_generator()
    critic = self.build_critic()
    enc = self.build_encoder()
    dec = self.build_decoder()
    enc2 = self.build_encoder2()
    dec2 = self.build_decoder2()
    switch = tf.Variable(-1.0, dtype=tf.float32)

    # Same architecture as `gen`, starting from identical weights.
    gen_ema = tf.keras.models.clone_model(gen)
    gen_ema.set_weights(gen.get_weights())

    opt_disc = tf.keras.optimizers.Adam(0.0001, 0.5)
    opt_dec = tf.keras.optimizers.Adam(0.0001, 0.5)
    if self.args.mixed_precision:
        # Wrap the optimizers in loss scaling for float16 training.
        opt_disc = self.mixed_precision.LossScaleOptimizer(opt_disc)
        opt_dec = self.mixed_precision.LossScaleOptimizer(opt_dec)

    return (
        critic,
        gen,
        enc,
        dec,
        enc2,
        dec2,
        gen_ema,
        [opt_dec, opt_disc],
        switch,
    )
def get_networks(self):
    """Load the checkpoints at ``args.load_path_1`` .. ``args.load_path_3``.

    Each checkpoint is loaded with ``load(..., load_dec=False)``. Only the
    ``gen_ema`` model differs between the three returned tuples: every
    other entry (critic, gen, enc, dec, enc2, dec2, optimizers, switch)
    comes from the last checkpoint loaded and is shared by all three.

    Returns:
        Three tuples of the form ``(critic, gen, enc, dec, enc2, dec2,
        gen_ema_i, [opt_dec, opt_disc], switch)``.
    """
    ema_models = []
    for attr in ("load_path_1", "load_path_2", "load_path_3"):
        (
            critic,
            gen,
            enc,
            dec,
            enc2,
            dec2,
            gen_ema,
            [opt_dec, opt_disc],
            switch,
        ) = self.load(getattr(self.args, attr), load_dec=False)
        print(f"Networks loaded from {getattr(self.args, attr)}")
        ema_models.append(gen_ema)

    # After the loop the shared names refer to the third checkpoint's objects.
    gen_ema_1, gen_ema_2, gen_ema_3 = ema_models
    return (
        (critic, gen, enc, dec, enc2, dec2, gen_ema_1, [opt_dec, opt_disc], switch),
        (critic, gen, enc, dec, enc2, dec2, gen_ema_2, [opt_dec, opt_disc], switch),
        (critic, gen, enc, dec, enc2, dec2, gen_ema_3, [opt_dec, opt_disc], switch),
    )
def initialize_networks(self):
    """Load the three network sets and print critic/generator parameter counts.

    The parameter counts and all shared entries of the result are taken
    from the third tuple returned by ``get_networks``; only the
    ``gen_ema`` entry is specific to each returned tuple.

    Returns:
        Three tuples of the form ``(critic, gen, enc, dec, enc2, dec2,
        gen_ema_i, [opt_dec, opt_disc], switch)``.
    """
    set_1, set_2, set_3 = self.get_networks()
    gen_ema_1 = set_1[6]
    gen_ema_2 = set_2[6]
    critic, gen, enc, dec, enc2, dec2, gen_ema_3, (opt_dec, opt_disc), switch = set_3

    print(f"Critic params: {count_params(critic.trainable_variables)}")
    print(f"Generator params: {count_params(gen.trainable_variables)}")

    return tuple(
        (critic, gen, enc, dec, enc2, dec2, ema, [opt_dec, opt_disc], switch)
        for ema in (gen_ema_1, gen_ema_2, gen_ema_3)
    )