""" S2G Training """
import datetime
import json
import os
import random
import sys
from pathlib import Path
import numpy as np
import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
from anim import quat
from anim.tquat import *
from anim.txform import *
from dataset import SGDataset
from helpers import flatten_dict, save_useful_info, progress
from modules import (
Decoder,
SpeechEncoder,
StyleEncoder,
compute_KL_div,
normalize,
)
from optimizers import RAdam
from utils import write_bvh


def train(
models_dir,
logs_dir,
path_processed_data,
path_data_definition,
train_options,
network_options,
):
# ===============================================
# Getting/Setting Training/Network Configs
# ===============================================
np.random.seed(train_options["seed"])
torch.manual_seed(train_options["seed"])
torch.set_num_threads(train_options["thread_count"])
use_gpu = train_options["use_gpu"] and torch.cuda.is_available()
use_script = train_options["use_script"]
if use_gpu:
print("Using GPU!")
else:
print("Using CPU!")
device = torch.device("cuda:0" if use_gpu else "cpu")
window = train_options["window"]
niterations = train_options["niterations"]
batchsize = train_options["batchsize"]
style_encoder_opts = network_options["style_encoder"]
speech_encoder_opts = network_options["speech_encoder"]
decoder_opts = network_options["decoder"]
# ===============================================
# Load Details
# ===============================================
with open(path_data_definition, "r") as f:
details = json.load(f)
nlabels = len(details["label_names"])
bone_names = details["bone_names"]
parents = torch.LongTensor(details["parents"])
dt = details["dt"]
# ===============================================
# Load Data
# ===============================================
ds = SGDataset(
path_data_definition,
path_processed_data,
window,
style_encoding_type=train_options["style_encoding_type"],
example_window_length=style_encoder_opts["example_length"],
)
# Workaround: The number of workers should be 0 so that the example length can be changed dynamically
dl = DataLoader(ds, drop_last=True, batch_size=batchsize, shuffle=True, num_workers=0)
dimensions = ds.get_shapes()
(
audio_input_mean,
audio_input_std,
anim_input_mean,
anim_input_std,
anim_output_mean,
anim_output_std,
) = ds.get_means_stds(device)
# ===============================================
# Load or Resume Networks
# ===============================================
style_encoding_type = train_options["style_encoding_type"]
if style_encoding_type == "label":
style_encoding_size = nlabels
elif style_encoding_type == "example":
style_encoding_size = style_encoder_opts["style_encoding_size"]
path_network_speech_encoder_weights = models_dir / "speech_encoder.pt"
path_network_decoder_weights = models_dir / "decoder.pt"
path_network_style_encoder_weights = models_dir / "style_encoder.pt"
path_checkpoints = models_dir / "checkpoints.pt"
    if (
        train_options["resume"]
        and os.path.exists(path_network_speech_encoder_weights)
        and os.path.exists(path_network_decoder_weights)
        and os.path.exists(path_network_style_encoder_weights)
        and os.path.exists(path_checkpoints)
    ):
network_speech_encoder = torch.load(path_network_speech_encoder_weights).to(device)
network_decoder = torch.load(path_network_decoder_weights).to(device)
network_style_encoder = torch.load(path_network_style_encoder_weights).to(device)
else:
network_speech_encoder = SpeechEncoder(
dimensions["num_audio_features"],
speech_encoder_opts["nhidden"],
speech_encoder_opts["speech_encoding_size"],
).to(device)
network_decoder = Decoder(
pose_input_size=dimensions["pose_input_size"],
pose_output_size=dimensions["pose_output_size"],
speech_encoding_size=speech_encoder_opts["speech_encoding_size"],
style_encoding_size=style_encoding_size,
hidden_size=decoder_opts["nhidden"],
num_rnn_layers=2,
).to(device)
network_style_encoder = StyleEncoder(
dimensions["pose_input_size"],
style_encoder_opts["nhidden"],
style_encoding_size,
type=style_encoder_opts["type"],
use_vae=style_encoder_opts["use_vae"],
).to(device)
if use_script:
network_speech_encoder_script = torch.jit.script(network_speech_encoder)
network_decoder_script = torch.jit.script(network_decoder)
network_style_encoder_script = torch.jit.script(network_style_encoder)
else:
network_speech_encoder_script = network_speech_encoder
network_decoder_script = network_decoder
network_style_encoder_script = network_style_encoder
# ===============================================
# Optimizer
# ===============================================
all_parameters = (
list(network_speech_encoder.parameters())
+ list(network_decoder.parameters())
        + (list(network_style_encoder.parameters()) if style_encoding_type == "example" else [])
)
optimizer = RAdam(all_parameters, lr=train_options["learning_rate"], eps=train_options["eps"])
scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer, train_options["learning_rate_decay"]
)
if train_options["resume"]:
checkpoints = torch.load(path_checkpoints)
iteration = checkpoints["iteration"]
epoch = checkpoints["epoch"]
loss = checkpoints["loss"]
        optimizer.load_state_dict(checkpoints["optimizer_state_dict"])
else:
iteration = 0
epoch = 0
# ===============================================
# Setting Log Directories
# ===============================================
samples_dir = logs_dir / "samples"
samples_dir.mkdir(exist_ok=True)
if train_options["use_tensorboard"]:
tb_dir = logs_dir / "tb"
tb_dir.mkdir(exist_ok=True)
writer = SummaryWriter(tb_dir, flush_secs=10)
hparams = flatten_dict(network_options)
hparams.update(flatten_dict(train_options))
writer.add_hparams(hparams, {"No Metric": 0.0})
# ===============================================
# Begin Training
# ===============================================
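    # niterations is given in thousands of iterations (hence the factor of 1000
    # below, and the learning-rate decay step every 1000 iterations)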
while iteration < (1000 * niterations):
start_time = datetime.datetime.now()
for batch_index, batch in enumerate(dl):
network_speech_encoder.train()
network_decoder.train()
network_style_encoder.train()
            (
                W_audio_features,  # (32, 256, 81)
                W_root_pos,        # (32, 256, 3)
                W_root_rot,        # (32, 256, 4)
                W_root_vel,        # (32, 256, 3)
                W_root_vrt,        # (32, 256, 3)
                W_lpos,            # (32, 256, 75, 3)
                W_ltxy,            # (32, 256, 75, 2, 3)
                W_lvel,            # (32, 256, 75, 3)
                W_lvrt,            # (32, 256, 75, 3)
                W_gaze_pos,        # (32, 256, 3)
                WStyle,            # (32, 256, 1134)
            ) = batch
W_audio_features = W_audio_features.to(device)
W_root_pos = W_root_pos.to(device)
W_root_rot = W_root_rot.to(device)
W_root_vel = W_root_vel.to(device)
W_root_vrt = W_root_vrt.to(device)
W_lpos = W_lpos.to(device)
W_ltxy = W_ltxy.to(device)
W_lvel = W_lvel.to(device)
W_lvrt = W_lvrt.to(device)
W_gaze_pos = W_gaze_pos.to(device)
WStyle = WStyle.to(device)
            # Dynamically change the example length for the next iteration
            # (this is why num_workers must be 0 above): draw an even window
            # length between example_length and 2 * example_length frames
            ds.example_window_length = 2 * random.randint(
                style_encoder_opts["example_length"] // 2,
                style_encoder_opts["example_length"],
            )
# Speech Encoder
speech_encoding = network_speech_encoder_script(
(W_audio_features - audio_input_mean) / audio_input_std
)
# Style Encoder
if style_encoding_type == "example":
WStyle = (WStyle - anim_input_mean) / anim_input_std
style_encoding, mu, logvar = network_style_encoder_script(
WStyle.to(device=device)
)
else:
style_encoding = WStyle
# Gesture Generator
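            # Autoregressive decoder: seeded with the first frame of the window,
            # conditioned on gaze, per-frame speech encodings, and the style
            # encoding tiled across the time dimension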
(
O_root_pos,
O_root_rot,
O_root_vel,
O_root_vrt,
O_lpos,
O_ltxy,
O_lvel,
O_lvrt,
) = network_decoder_script(
W_root_pos[:, 0],
W_root_rot[:, 0],
W_root_vel[:, 0],
W_root_vrt[:, 0],
W_lpos[:, 0],
W_ltxy[:, 0],
W_lvel[:, 0],
W_lvrt[:, 0],
W_gaze_pos,
speech_encoding,
style_encoding.unsqueeze(1).repeat((1, speech_encoding.shape[1], 1)),
parents,
anim_input_mean,
anim_input_std,
anim_output_mean,
anim_output_std,
dt,
)
# Compute Character/World Space
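            # ltxy appears to store the x/y axes of each joint's rotation (a 6D
            # rotation representation); orthogonalize into full 3x3 matrices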
W_lmat = xform_orthogonalize_from_xy(W_ltxy)
O_lmat = xform_orthogonalize_from_xy(O_ltxy)
## Root Velocities to World Space
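            # Velocities are predicted in the root's local frame; rotate frame t
            # by the root rotation of frame t-1 (frame 0 uses its own rotation)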
O_root_vel_1_ = quat_mul_vec(O_root_rot[:, :-1], O_root_vel[:, 1:])
O_root_vrt_1_ = quat_mul_vec(O_root_rot[:, :-1], O_root_vrt[:, 1:])
O_root_vel_0 = quat_mul_vec(O_root_rot[:, 0:1], O_root_vel[:, 0:1])
O_root_vrt_0 = quat_mul_vec(O_root_rot[:, 0:1], O_root_vrt[:, 0:1])
O_root_vel = torch.cat((O_root_vel_0, O_root_vel_1_), dim=1)
O_root_vrt = torch.cat((O_root_vrt_0, O_root_vrt_1_), dim=1)
W_root_vel_1_ = quat_mul_vec(W_root_rot[:, :-1], W_root_vel[:, 1:])
W_root_vrt_1_ = quat_mul_vec(W_root_rot[:, :-1], W_root_vrt[:, 1:])
W_root_vel_0 = quat_mul_vec(W_root_rot[:, 0:1], W_root_vel[:, 0:1])
W_root_vrt_0 = quat_mul_vec(W_root_rot[:, 0:1], W_root_vrt[:, 0:1])
W_root_vel = torch.cat((W_root_vel_0, W_root_vel_1_), dim=1)
W_root_vrt = torch.cat((W_root_vrt_0, W_root_vrt_1_), dim=1)
## Update First Joint
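            # Fold the root transform into joint 0 so that FK below yields
            # world-space values; the cross product adds the angular-velocity
            # contribution w x (R * r) to the first joint's linear velocity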
O_lpos_0 = quat_mul_vec(O_root_rot, O_lpos[:, :, 0]) + O_root_pos
O_lmat_0 = torch.matmul(quat_to_xform(O_root_rot), O_lmat[:, :, 0])
O_lvel_0 = (
O_root_vel
+ quat_mul_vec(O_root_rot, O_lvel[:, :, 0])
                + torch.cross(O_root_vrt, quat_mul_vec(O_root_rot, O_lpos[:, :, 0]), dim=-1)
)
O_lvrt_0 = O_root_vrt + quat_mul_vec(O_root_rot, O_lvrt[:, :, 0])
O_lpos = torch.cat((O_lpos_0.unsqueeze(2), O_lpos[:, :, 1:]), dim=2)
O_lmat = torch.cat((O_lmat_0.unsqueeze(2), O_lmat[:, :, 1:]), dim=2)
O_lvel = torch.cat((O_lvel_0.unsqueeze(2), O_lvel[:, :, 1:]), dim=2)
O_lvrt = torch.cat((O_lvrt_0.unsqueeze(2), O_lvrt[:, :, 1:]), dim=2)
W_lpos_0 = quat_mul_vec(W_root_rot, W_lpos[:, :, 0]) + W_root_pos
W_lmat_0 = torch.matmul(quat_to_xform(W_root_rot), W_lmat[:, :, 0])
W_lvel_0 = (
W_root_vel
+ quat_mul_vec(W_root_rot, W_lvel[:, :, 0])
                + torch.cross(W_root_vrt, quat_mul_vec(W_root_rot, W_lpos[:, :, 0]), dim=-1)
)
W_lvrt_0 = W_root_vrt + quat_mul_vec(W_root_rot, W_lvrt[:, :, 0])
W_lpos = torch.cat((W_lpos_0.unsqueeze(2), W_lpos[:, :, 1:]), dim=2)
W_lmat = torch.cat((W_lmat_0.unsqueeze(2), W_lmat[:, :, 1:]), dim=2)
W_lvel = torch.cat((W_lvel_0.unsqueeze(2), W_lvel[:, :, 1:]), dim=2)
W_lvrt = torch.cat((W_lvrt_0.unsqueeze(2), W_lvrt[:, :, 1:]), dim=2)
            # FK to Character/World Space
W_cmat, W_cpos, W_cvrt, W_cvel = xform_fk_vel(
W_lmat, W_lpos, W_lvrt, W_lvel, parents
)
O_cmat, O_cpos, O_cvrt, O_cvel = xform_fk_vel(
O_lmat, O_lpos, O_lvrt, O_lvel, parents
)
O_root_mat = quat_to_xform(O_root_rot)
W_root_mat = quat_to_xform(W_root_rot)
# Compute Gaze Dirs
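            # Direction from the (ground-truth vs. predicted) root to the gaze
            # target, expressed in the respective root's local frame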
W_gaze_dir = quat_inv_mul_vec(W_root_rot, normalize(W_gaze_pos - W_root_pos))
O_gaze_dir = quat_inv_mul_vec(O_root_rot, normalize(W_gaze_pos - O_root_pos))
# Compute Losses
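            # Weighted L1 losses; the scale factors below look hand-tuned to
            # balance term magnitudes across position/rotation/velocity spaces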
loss_root_pos = torch.mean(torch.abs(0.1 * (O_root_pos - W_root_pos)))
loss_root_rot = torch.mean(torch.abs(10.0 * (O_root_mat - W_root_mat)))
loss_root_vel = torch.mean(torch.abs(0.1 * (O_root_vel - W_root_vel)))
loss_root_vrt = torch.mean(torch.abs(5.0 * (O_root_vrt - W_root_vrt)))
loss_lpos = torch.mean(torch.abs(15.0 * (O_lpos - W_lpos)))
loss_lrot = torch.mean(torch.abs(15.0 * (O_ltxy - W_ltxy)))
loss_lvel = torch.mean(torch.abs(10.0 * (O_lvel - W_lvel)))
loss_lvrt = torch.mean(torch.abs(7.0 * (O_lvrt - W_lvrt)))
loss_cpos = torch.mean(torch.abs(0.1 * (O_cpos - W_cpos)))
loss_crot = torch.mean(torch.abs(3.0 * (O_cmat - W_cmat)))
loss_cvel = torch.mean(torch.abs(0.06 * (O_cvel - W_cvel)))
loss_cvrt = torch.mean(torch.abs(1.25 * (O_cvrt - W_cvrt)))
loss_ldvl = torch.mean(
torch.abs(
7.0
* (
(O_lpos[:, 1:] - O_lpos[:, :-1]) / dt
- (W_lpos[:, 1:] - W_lpos[:, :-1]) / dt
)
)
)
loss_ldvt = torch.mean(
torch.abs(
8.0
* (
(O_ltxy[:, 1:] - O_ltxy[:, :-1]) / dt
- (W_ltxy[:, 1:] - W_ltxy[:, :-1]) / dt
)
)
)
loss_cdvl = torch.mean(
torch.abs(
0.06
* (
(O_cpos[:, 1:] - O_cpos[:, :-1]) / dt
- (W_cpos[:, 1:] - W_cpos[:, :-1]) / dt
)
)
)
loss_cdvt = torch.mean(
torch.abs(
1.25
* (
(O_cmat[:, 1:] - O_cmat[:, :-1]) / dt
- (W_cmat[:, 1:] - W_cmat[:, :-1]) / dt
)
)
)
loss_gaze = torch.mean(torch.abs(10.0 * (O_gaze_dir - W_gaze_dir)))
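            # KL divergence term, used only with the VAE-style example encoder;
            # compute_KL_div presumably anneals its weight over training, given
            # the iteration argument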
loss_kl_div = 0.0
if mu is not None and logvar is not None:
kl_div, kl_div_weight = compute_KL_div(mu, logvar, iteration)
loss_kl_div = kl_div_weight * torch.mean(kl_div)
loss = (
                loss_root_pos
+ loss_root_rot
+ loss_root_vel
+ loss_root_vrt
+ loss_lpos
+ loss_lrot
+ loss_lvel
+ loss_lvrt
+ loss_cpos
+ loss_crot
+ loss_cvel
+ loss_cvrt
+ loss_ldvl
+ loss_ldvt
+ loss_cdvl
+ loss_cdvt
+ loss_gaze
+ loss_kl_div
            ) / 18.0  # mean of the 18 weighted loss terms
# Backward
loss.backward()
optimizer.step()
# Zero Gradients
optimizer.zero_grad()
losses = loss.detach().item()
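            # Decay the learning rate once every 1000 iterations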
if (iteration + 1) % 1000 == 0:
scheduler.step()
# ===================================================
# Logging, Generating Samples
# ===================================================
if train_options["use_tensorboard"]:
writer.add_scalar("losses/total_loss", loss, iteration)
writer.add_scalars(
"losses/losses",
{
"loss_root_pos": loss_root_pos,
"loss_root_rot": loss_root_rot,
"loss_root_vel": loss_root_vel,
"loss_root_vrt": loss_root_vrt,
"loss_lpos": loss_lpos,
"loss_lrot": loss_lrot,
"loss_lvel": loss_lvel,
"loss_lvrt": loss_lvrt,
"loss_cpos": loss_cpos,
"loss_crot": loss_crot,
"loss_cvel": loss_cvel,
"loss_cvrt": loss_cvrt,
"loss_ldvl": loss_ldvl,
"loss_ldvt": loss_ldvt,
"loss_cdvl": loss_cdvl,
"loss_cdvt": loss_cdvt,
"loss_gaze": loss_gaze,
"loss_kl_div": loss_kl_div,
},
iteration,
)
            if (iteration + 1) % 1 == 0:  # print progress every iteration; raise the modulus to print less often
sys.stdout.write(
"\r"
+ progress(
epoch,
iteration,
batch_index,
np.mean(losses),
(len(ds) // batchsize),
start_time,
)
)
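            # Periodically save the networks (as full modules), the optimizer
            # state, and sample BVH clips for train/valid sequences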
if iteration % train_options["generate_samples_step"] == 0:
sys.stdout.write(
"\r| Saving Networks... |"
)
torch.save(network_speech_encoder, path_network_speech_encoder_weights)
torch.save(network_decoder, path_network_decoder_weights)
torch.save(network_style_encoder, path_network_style_encoder_weights)
                torch.save({
                    "iteration": iteration,
                    "epoch": epoch,
                    "loss": loss,
                    "optimizer_state_dict": optimizer.state_dict(),
                }, models_dir / "checkpoints.pt")
current_models_dir = models_dir / str(iteration)
current_models_dir.mkdir(exist_ok=True)
path_network_speech_encoder_weights_current = current_models_dir / "speech_encoder.pt"
path_network_decoder_weights_current = current_models_dir / "decoder.pt"
path_network_style_encoder_weights_current = current_models_dir / "style_encoder.pt"
torch.save(network_speech_encoder, path_network_speech_encoder_weights_current)
torch.save(network_decoder, path_network_decoder_weights_current)
torch.save(network_style_encoder, path_network_style_encoder_weights_current)
                torch.save({
                    "iteration": iteration,
                    "epoch": epoch,
                    "loss": loss,
                    "optimizer_state_dict": optimizer.state_dict(),
                }, current_models_dir / "checkpoints.pt")
with torch.no_grad():
network_speech_encoder.eval()
network_decoder.eval()
network_style_encoder.eval()
sys.stdout.write(
"\r| Generating Animation... |"
)
# Write training animation
for i in range(3):
(
S_audio_features,
S_root_pos,
S_root_rot,
S_root_vel,
S_root_vrt,
S_lpos,
S_ltxy,
S_lvel,
S_lvrt,
S_gaze_pos,
label,
se,
range_index,
) = ds.get_sample("train", 30)
speech_encoding = network_speech_encoder_script(
(S_audio_features.to(device=device) - audio_input_mean)
/ audio_input_std
)
if style_encoding_type == "example":
example = ds.get_example(se, se, ds.example_window_length)
example = (example.to(device=device) - anim_input_mean) / anim_input_std
style_encoding, _, _ = network_style_encoder_script(example[np.newaxis])
else:
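                            # Fall back to a one-hot label vector as the style encoding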
style_encoding = np.zeros([nlabels])
style_encoding[label] = 1.0
style_encoding = torch.as_tensor(
style_encoding, dtype=torch.float32, device=device
)[np.newaxis]
(
V_root_pos,
V_root_rot,
_,
_,
V_lpos,
V_ltxy,
_,
_,
) = network_decoder_script(
S_root_pos[:, 0].to(device=device),
S_root_rot[:, 0].to(device=device),
S_root_vel[:, 0].to(device=device),
S_root_vrt[:, 0].to(device=device),
S_lpos[:, 0].to(device=device),
S_ltxy[:, 0].to(device=device),
S_lvel[:, 0].to(device=device),
S_lvrt[:, 0].to(device=device),
S_gaze_pos.to(device=device),
speech_encoding,
style_encoding.unsqueeze(1).repeat((1, speech_encoding.shape[1], 1)),
parents,
anim_input_mean,
anim_input_std,
anim_output_mean,
anim_output_std,
dt,
)
S_lrot = quat.from_xform(xform_orthogonalize_from_xy(S_ltxy).cpu().numpy())
V_lrot = quat.from_xform(xform_orthogonalize_from_xy(V_ltxy).cpu().numpy())
try:
current_label = details["label_names"][label]
write_bvh(
str(
samples_dir
/ (
f"iteration_{iteration}_train_ground_{i}_{current_label}.bvh"
)
),
S_root_pos[0].cpu().numpy(),
S_root_rot[0].cpu().numpy(),
S_lpos[0].cpu().numpy(),
S_lrot[0],
parents=parents.cpu().numpy(),
names=bone_names,
order="zyx",
dt=dt,
)
write_bvh(
str(
samples_dir
/ (
f"iteration_{iteration}_train_predict_{i}_{current_label}.bvh"
)
),
V_root_pos[0].cpu().numpy(),
V_root_rot[0].cpu().numpy(),
V_lpos[0].cpu().numpy(),
V_lrot[0],
parents=parents.cpu().numpy(),
names=bone_names,
order="zyx",
dt=dt,
)
except (PermissionError, OSError) as e:
print(e)
# Write validation animation
for i in range(3):
(
S_audio_features,
S_root_pos,
S_root_rot,
S_root_vel,
S_root_vrt,
S_lpos,
S_ltxy,
S_lvel,
S_lvrt,
S_gaze_pos,
label,
se,
range_index,
) = ds.get_sample("valid", 30)
speech_encoding = network_speech_encoder_script(
(S_audio_features.to(device=device) - audio_input_mean)
/ audio_input_std
)
if style_encoding_type == "example":
example = ds.get_example(se, se, ds.example_window_length)
example = (example.to(device=device) - anim_input_mean) / anim_input_std
style_encoding, _, _ = network_style_encoder_script(example[np.newaxis])
else:
style_encoding = np.zeros([nlabels])
style_encoding[label] = 1.0
style_encoding = torch.as_tensor(
style_encoding, dtype=torch.float32, device=device
)[np.newaxis]
(
V_root_pos,
V_root_rot,
_,
_,
V_lpos,
V_ltxy,
_,
_,
) = network_decoder_script(
S_root_pos[:, 0].to(device=device),
S_root_rot[:, 0].to(device=device),
S_root_vel[:, 0].to(device=device),
S_root_vrt[:, 0].to(device=device),
S_lpos[:, 0].to(device=device),
S_ltxy[:, 0].to(device=device),
S_lvel[:, 0].to(device=device),
S_lvrt[:, 0].to(device=device),
S_gaze_pos.to(device=device),
speech_encoding,
style_encoding.unsqueeze(1).repeat((1, speech_encoding.shape[1], 1)),
parents,
anim_input_mean,
anim_input_std,
anim_output_mean,
anim_output_std,
dt,
)
S_lrot = quat.from_xform(xform_orthogonalize_from_xy(S_ltxy).cpu().numpy())
V_lrot = quat.from_xform(xform_orthogonalize_from_xy(V_ltxy).cpu().numpy())
try:
current_label = details["label_names"][label]
write_bvh(
str(
samples_dir
/ (
f"iteration_{iteration}_valid_ground_{i}_{current_label}.bvh"
)
),
S_root_pos[0].cpu().numpy(),
S_root_rot[0].cpu().numpy(),
S_lpos[0].cpu().numpy(),
S_lrot[0],
parents=parents.cpu().numpy(),
names=bone_names,
order="zyx",
dt=dt,
)
write_bvh(
str(
samples_dir
/ (
f"iteration_{iteration}_valid_predict_{i}_{current_label}.bvh"
)
),
V_root_pos[0].cpu().numpy(),
V_root_rot[0].cpu().numpy(),
V_lpos[0].cpu().numpy(),
V_lrot[0],
parents=parents.cpu().numpy(),
names=bone_names,
order="zyx",
dt=dt,
)
except (PermissionError, OSError) as e:
print(e)
iteration += 1
sys.stdout.write("\n")
epoch += 1
print("Done!")


if __name__ == "__main__":
    # Config path (hard-coded here for debugging / standalone runs)
    options = "../configs/configs_v1.json"
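    # Expected layout of configs_v1.json: a minimal sketch inferred from the
    # keys this script reads; the values shown are illustrative, not the
    # shipped defaults:
    # {
    #     "train_opt": {
    #         "seed": 1234, "thread_count": 8, "use_gpu": true, "use_script": false,
    #         "window": 256, "niterations": 500, "batchsize": 32,
    #         "style_encoding_type": "example", "learning_rate": 1e-4, "eps": 1e-8,
    #         "learning_rate_decay": 0.999, "resume": false, "use_tensorboard": true,
    #         "generate_samples_step": 1000
    #     },
    #     "net_opt": {
    #         "style_encoder": {"example_length": ..., "style_encoding_size": ...,
    #                           "nhidden": ..., "type": ..., "use_vae": ...},
    #         "speech_encoder": {"nhidden": ..., "speech_encoding_size": ...},
    #         "decoder": {"nhidden": ...}
    #     },
    #     "paths": {"base_path": ..., "path_processed_data": ...,
    #               "output_dir": null, "models_dir": null}
    # }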
with open(options, "r") as f:
options = json.load(f)
train_options = options["train_opt"]
network_options = options["net_opt"]
paths = options["paths"]
base_path = Path(paths["base_path"])
path_processed_data = base_path / paths["path_processed_data"] / "processed_data.npz"
path_data_definition = base_path / paths["path_processed_data"] / "data_definition.json"
# Output directory
if paths["output_dir"] is None:
output_dir = (base_path / "outputs") / datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
output_dir.mkdir(exist_ok=True, parents=True)
paths["output_dir"] = str(output_dir)
else:
output_dir = Path(paths["output_dir"])
# Path to models
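    # NOTE: this assumes paths["models_dir"] is set in the config whenever
    # resume is enabled; otherwise Path(None) below raises a TypeError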
if paths["models_dir"] is None and not train_options["resume"]:
models_dir = output_dir / "saved_models"
models_dir.mkdir(exist_ok=True)
paths["models_dir"] = str(models_dir)
else:
models_dir = Path(paths["models_dir"])
# Log directory
logs_dir = output_dir / "logs"
logs_dir.mkdir(exist_ok=True)
options["paths"] = paths
    with open(output_dir / "options.json", "w") as fp:
        json.dump(options, fp, indent=4)
save_useful_info(output_dir)
train(
models_dir=models_dir,
logs_dir=logs_dir,
path_processed_data=path_processed_data,
path_data_definition=path_data_definition,
train_options=train_options,
network_options=network_options,
)
print("Done!")