# ThaiVecFont / train.py
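"""Training entry point for ThaiVecFont.

Trains the main dual-modality (raster image + SVG sequence) model, logs losses and
sample images to Weights & Biases, and writes periodic checkpoints under
<exp_path>/experiments/<name_exp>.
"""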
import os
import random
import numpy as np
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torchvision.utils import save_image
import wandb
from dataloader import get_loader
from models import util_funcs
from models.model_main import ModelMain
from options import get_parser_main_model
from data_utils.svg_utils import render
from time import time

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def train_main_model(opts):
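    """Train ModelMain: joint image + SVG losses, wandb logging, periodic sampling, validation, checkpoints."""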
    setup_seed(opts.seed)

    dir_exp = os.path.join(f"{opts.exp_path}", "experiments", opts.name_exp)
    dir_sample = os.path.join(dir_exp, "samples")
    dir_ckpt = os.path.join(dir_exp, "checkpoints")
    dir_log = os.path.join(dir_exp, "logs")
    logfile_train = open(os.path.join(dir_log, "train_loss_log.txt"), 'w')
    logfile_val = open(os.path.join(dir_log, "val_loss_log.txt"), 'w')

    train_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size, opts.mode)
    val_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num, opts.max_seq_len, opts.dim_seq, opts.batch_size_val, 'val')

    run = wandb.init(project=opts.wandb_project_name, config=opts)  # initialize wandb project
    text_table = wandb.Table(columns=["epoch", "loss", "ref"])
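
    # build the model; optionally wrap in DataParallel and/or resume from a checkpoint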
    model_main = ModelMain(opts)
    if torch.cuda.is_available() and opts.multi_gpu:
        model_main = torch.nn.DataParallel(model_main)
    if opts.continue_training:
        model_main.load_state_dict(torch.load(opts.continue_ckpt)['model'])
    model_main.cuda()

    # one optimizer parameter group per sub-module
    # (note: these attribute accesses assume model_main is not wrapped in DataParallel;
    #  with opts.multi_gpu they would need to go through model_main.module)
    parameters_all = [{"params": model_main.img_encoder.parameters()},
                      {"params": model_main.img_decoder.parameters()},
                      {"params": model_main.modality_fusion.parameters()},
                      {"params": model_main.transformer_main.parameters()},
                      {"params": model_main.transformer_seqdec.parameters()}]

    optimizer = AdamW(parameters_all, lr=opts.lr, betas=(opts.beta1, opts.beta2), eps=opts.eps, weight_decay=opts.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)
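
    # epoch loop: forward/backward on every batch, with periodic logging, sampling and validation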
    for epoch in range(opts.init_epoch, opts.n_epochs):
        t0 = time()
        for idx, data in enumerate(train_loader):
            # move the batch to the GPU
            for key in data:
                data[key] = data[key].cuda()
            ret_dict, loss_dict = model_main(data)

            # weighted sum of image L1, perceptual (VGG), KL and SVG sequence losses
            loss = opts.loss_w_l1 * loss_dict['img']['l1'] + opts.loss_w_pt_c * loss_dict['img']['vggpt'] + opts.kl_beta * loss_dict['kl'] \
                + loss_dict['svg']['total'] + loss_dict['svg_para']['total']

            # perform optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batches_done = epoch * len(train_loader) + idx + 1
            message = (
                f"Time: {time() - t0:.2f} seconds, "
                f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
                f"Loss: {loss.item():.6f}, "
                f"img_l1_loss: {opts.loss_w_l1 * loss_dict['img']['l1'].item():.6f}, "
                f"img_pt_c_loss: {opts.loss_w_pt_c * loss_dict['img']['vggpt'].item():.6f}, "
                f"svg_total_loss: {loss_dict['svg']['total'].item():.6f}, "
                f"svg_cmd_loss: {opts.loss_w_cmd * loss_dict['svg']['cmd'].item():.6f}, "
                f"svg_args_loss: {opts.loss_w_args * loss_dict['svg']['args'].item():.6f}, "
                f"svg_smooth_loss: {opts.loss_w_smt * loss_dict['svg']['smt'].item():.6f}, "
                f"svg_aux_loss: {opts.loss_w_aux * loss_dict['svg']['aux'].item():.6f}, "
                f"lr: {optimizer.param_groups[0]['lr']:.6f}, "
                f"Step: {batches_done}"
            )

            if batches_done % opts.freq_log == 0:
                logfile_train.write(message + '\n')
                print(message)

                if opts.wandb:
                    # define the items for image and SVG losses
                    loss_img_items = ['l1', 'vggpt']
                    loss_svg_items = ['total', 'cmd', 'args', 'aux', 'smt']
                    # log image loss items
                    for item in loss_img_items:
                        wandb.log({f'Loss/img_{item}': loss_dict['img'][item].item()}, step=batches_done)
                    # log SVG loss items (main and parallel decoder branches)
                    for item in loss_svg_items:
                        wandb.log({f'Loss/svg_{item}': loss_dict['svg'][item].item()}, step=batches_done)
                        wandb.log({f'Loss/svg_para_{item}': loss_dict['svg_para'][item].item()}, step=batches_done)
                    # log KL loss
                    wandb.log({'Loss/img_kl_loss': opts.kl_beta * loss_dict['kl'].item()}, step=batches_done)
                    # log target / output image pairs
                    wandb.log({
                        'Images/trg_img': wandb.Image(ret_dict['img']['trg'][0], caption="Target"),
                        'Images/img_output': wandb.Image(ret_dict['img']['out'][0], caption="Output")
                    }, step=batches_done)
                    text_table.add_data(epoch, loss.item(), str(ret_dict['img']['ref'][0]))
                    wandb.log({"training_samples": text_table})
            if opts.freq_sample > 0 and batches_done % opts.freq_sample == 0:
                img_sample = torch.cat((ret_dict['img']['trg'].data, ret_dict['img']['out'].data), -2)
                save_file = os.path.join(dir_sample, f"train_epoch_{epoch}_batch_{batches_done}.png")
                save_image(img_sample, save_file, nrow=8, normalize=True)
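
            # periodic validation over the whole validation split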
            if opts.freq_val > 0 and batches_done % opts.freq_val == 0:
                with torch.no_grad():
                    model_main.eval()
                    loss_val = {'img': {'l1': 0.0, 'vggpt': 0.0},
                                'svg': {'total': 0.0, 'cmd': 0.0, 'args': 0.0, 'aux': 0.0},
                                'svg_para': {'total': 0.0, 'cmd': 0.0, 'args': 0.0, 'aux': 0.0}}

                    # accumulate losses over all validation batches
                    for val_idx, val_data in enumerate(val_loader):
                        for key in val_data:
                            val_data[key] = val_data[key].cuda()
                        ret_dict_val, loss_dict_val = model_main(val_data, mode='val')
                        for loss_cat in ['img', 'svg']:
                            for key, _ in loss_val[loss_cat].items():
                                loss_val[loss_cat][key] += loss_dict_val[loss_cat][key]

                    # average over the number of validation batches
                    for loss_cat in ['img', 'svg']:
                        for key, _ in loss_val[loss_cat].items():
                            loss_val[loss_cat][key] /= len(val_loader)

                    if opts.wandb:
                        for loss_cat in ['img', 'svg']:
                            for key, value in loss_val[loss_cat].items():
                                wandb.log({f'VAL/loss_{loss_cat}_{key}': value})

                    val_msg = (
                        f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
                        f"Val loss img l1: {loss_val['img']['l1']: .6f}, "
                        f"Val loss img pt: {loss_val['img']['vggpt']: .6f}, "
                        f"Val loss total: {loss_val['svg']['total']: .6f}, "
                        f"Val loss cmd: {loss_val['svg']['cmd']: .6f}, "
                        f"Val loss args: {loss_val['svg']['args']: .6f}"
                    )
                    logfile_val.write(val_msg + "\n")
                    print(val_msg)
                model_main.train()  # switch back to training mode after validation

        scheduler.step()

        # save a checkpoint (and upload it as a wandb artifact) every opts.freq_ckpt epochs
        if epoch % opts.freq_ckpt == 0 and epoch >= opts.threshold_ckpt:
            if opts.multi_gpu:
                print(f"Saved {dir_ckpt}/{epoch}_{batches_done}.ckpt")
                torch.save({'model': model_main.module.state_dict(), 'opt': optimizer.state_dict(), 'n_epoch': epoch, 'n_iter': batches_done}, f'{dir_ckpt}/{epoch}_{batches_done}.ckpt')
            else:
                print(f"Saved {dir_ckpt}/{epoch}_{batches_done}.ckpt")
                torch.save({'model': model_main.state_dict(), 'opt': optimizer.state_dict(), 'n_epoch': epoch, 'n_iter': batches_done}, f'{dir_ckpt}/{epoch}_{batches_done}.ckpt')
            if opts.wandb:
                artifact = wandb.Artifact('model_main_checkpoints', type='model')
                artifact.add_file(f'{dir_ckpt}/{epoch}_{batches_done}.ckpt')
                run.log_artifact(artifact)

    logfile_train.close()
    logfile_val.close()
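

# snapshot the key source files into the experiment folder so each run is reproducible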
def backup_code(name_exp, exp_path):
    dir_code = os.path.join(exp_path, 'experiments', name_exp, 'code')
    os.makedirs(dir_code, exist_ok=True)
    shutil.copy('models/transformers.py', os.path.join(dir_code, 'transformers.py'))
    shutil.copy('models/model_main.py', os.path.join(dir_code, 'model_main.py'))
    shutil.copy('models/image_encoder.py', os.path.join(dir_code, 'image_encoder.py'))
    shutil.copy('models/image_decoder.py', os.path.join(dir_code, 'image_decoder.py'))
    shutil.copy('./train.py', os.path.join(dir_code, 'train.py'))
    shutil.copy('./options.py', os.path.join(dir_code, 'options.py'))


def train(opts):
    if opts.model_name == 'main_model':
        train_main_model(opts)
    elif opts.model_name == 'others':
        # note: train_others is not defined (or imported) in this file; this branch fails unless it is provided elsewhere
        train_others(opts)
    else:
        raise NotImplementedError


def main():
    opts = get_parser_main_model().parse_args()
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    os.makedirs(f"{opts.exp_path}/experiments", exist_ok=True)
    debug = True

    # create experiment directories
    experiment_dir = os.path.join(f"{opts.exp_path}", "experiments", opts.name_exp)
    backup_code(opts.name_exp, opts.exp_path)
    os.makedirs(experiment_dir, exist_ok=debug)  # set debug = False to fail fast if the experiment dir already exists
    os.makedirs(os.path.join(experiment_dir, "samples"), exist_ok=True)
    os.makedirs(os.path.join(experiment_dir, "checkpoints"), exist_ok=True)
    os.makedirs(os.path.join(experiment_dir, "results"), exist_ok=True)
    os.makedirs(os.path.join(experiment_dir, "logs"), exist_ok=True)
    print(f"Training on experiment {opts.name_exp}...")

    # dump options
    with open(os.path.join(experiment_dir, "opts.txt"), "w") as f:
        for key, value in vars(opts).items():
            f.write(str(key) + ": " + str(value) + "\n")

    train(opts)


if __name__ == "__main__":
    main()
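
# Example invocation (assumed flag names; see options.py / get_parser_main_model for the actual arguments):
#   python train.py --mode train --model_name main_model --name_exp thaifont_run --data_root ./data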