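"""Training script for the main model: builds the data loaders, model, optimizer
and LR schedule, then runs the train/validation loop with logging, image
sampling, Weights & Biases tracking and periodic checkpointing."""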
import os
import random
import shutil
from time import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, AdamW
from torchvision.utils import save_image
import wandb

from dataloader import get_loader
from models import util_funcs
from models.model_main import ModelMain
from options import get_parser_main_model
from data_utils.svg_utils import render
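# Seed all relevant RNGs (PyTorch CPU/CUDA, NumPy, Python's random) and force
# deterministic cuDNN kernels so training runs are reproducible.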
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
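# Full training routine for the main model: prepares output directories and log
# files, builds the train/val loaders, model, AdamW optimizer and exponential LR
# schedule, then alternates optimization steps with periodic logging, image
# sampling, validation and checkpointing.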
def train_main_model(opts):
    setup_seed(opts.seed)

    dir_exp = os.path.join(f"{opts.exp_path}", "experiments", opts.name_exp)
    dir_sample = os.path.join(dir_exp, "samples")
    dir_ckpt = os.path.join(dir_exp, "checkpoints")
    dir_log = os.path.join(dir_exp, "logs")
    logfile_train = open(os.path.join(dir_log, "train_loss_log.txt"), 'w')
    logfile_val = open(os.path.join(dir_log, "val_loss_log.txt"), 'w')

    train_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num,
                              opts.max_seq_len, opts.dim_seq, opts.batch_size, opts.mode)
    val_loader = get_loader(opts.data_root, opts.img_size, opts.language, opts.char_num,
                            opts.max_seq_len, opts.dim_seq, opts.batch_size_val, 'val')

    run = wandb.init(project=opts.wandb_project_name, config=opts)  # initialize the wandb run
    text_table = wandb.Table(columns=["epoch", "loss", "ref"])
    model_main = ModelMain(opts)
    if torch.cuda.is_available() and opts.multi_gpu:
        model_main = torch.nn.DataParallel(model_main)
    if opts.continue_training:
        model_main.load_state_dict(torch.load(opts.continue_ckpt)['model'])
    model_main.cuda()

    # Unwrap DataParallel (when used) so the sub-modules stay addressable below.
    net = model_main.module if isinstance(model_main, torch.nn.DataParallel) else model_main
    parameters_all = [{"params": net.img_encoder.parameters()},
                      {"params": net.img_decoder.parameters()},
                      {"params": net.modality_fusion.parameters()},
                      {"params": net.transformer_main.parameters()},
                      {"params": net.transformer_seqdec.parameters()}]
    optimizer = AdamW(parameters_all, lr=opts.lr, betas=(opts.beta1, opts.beta2),
                      eps=opts.eps, weight_decay=opts.weight_decay)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)  # decay the LR by 0.3% per epoch
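    # Each batch minimizes a weighted sum of the image losses (L1 and VGG
    # perceptual), the KL term, and the two SVG sequence losses returned by the model.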
    for epoch in range(opts.init_epoch, opts.n_epochs):
        t0 = time()
        for idx, data in enumerate(train_loader):
            for key in data:
                data[key] = data[key].cuda()
            ret_dict, loss_dict = model_main(data)
            loss = (opts.loss_w_l1 * loss_dict['img']['l1']
                    + opts.loss_w_pt_c * loss_dict['img']['vggpt']
                    + opts.kl_beta * loss_dict['kl']
                    + loss_dict['svg']['total']
                    + loss_dict['svg_para']['total'])

            # perform optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batches_done = epoch * len(train_loader) + idx + 1
            message = (
                f"Time: {time() - t0:.2f} seconds, "
                f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
                f"Loss: {loss.item():.6f}, "
                f"img_l1_loss: {opts.loss_w_l1 * loss_dict['img']['l1'].item():.6f}, "
                f"img_pt_c_loss: {opts.loss_w_pt_c * loss_dict['img']['vggpt'].item():.6f}, "
                f"svg_total_loss: {loss_dict['svg']['total'].item():.6f}, "
                f"svg_cmd_loss: {opts.loss_w_cmd * loss_dict['svg']['cmd'].item():.6f}, "
                f"svg_args_loss: {opts.loss_w_args * loss_dict['svg']['args'].item():.6f}, "
                f"svg_smooth_loss: {opts.loss_w_smt * loss_dict['svg']['smt'].item():.6f}, "
                f"svg_aux_loss: {opts.loss_w_aux * loss_dict['svg']['aux'].item():.6f}, "
                f"lr: {optimizer.param_groups[0]['lr']:.6f}, "
                f"Step: {batches_done}"
            )
            if batches_done % opts.freq_log == 0:
                logfile_train.write(message + '\n')
                print(message)

                if opts.wandb:
                    # Image and SVG loss terms to log individually
                    loss_img_items = ['l1', 'vggpt']
                    loss_svg_items = ['total', 'cmd', 'args', 'aux', 'smt']
                    # Log image loss items
                    for item in loss_img_items:
                        wandb.log({f'Loss/img_{item}': loss_dict['img'][item].item()}, step=batches_done)
                    # Log SVG loss items (main and parallel branches)
                    for item in loss_svg_items:
                        wandb.log({f'Loss/svg_{item}': loss_dict['svg'][item].item()}, step=batches_done)
                        wandb.log({f'Loss/svg_para_{item}': loss_dict['svg_para'][item].item()}, step=batches_done)
                    # Log KL loss
                    wandb.log({'Loss/img_kl_loss': opts.kl_beta * loss_dict['kl'].item()}, step=batches_done)
                    # Log the current target/output image pair
                    wandb.log({
                        'Images/trg_img': wandb.Image(ret_dict['img']['trg'][0], caption="Target"),
                        'Images/img_output': wandb.Image(ret_dict['img']['out'][0], caption="Output")
                    }, step=batches_done)
                    text_table.add_data(epoch, loss.item(), str(ret_dict['img']['ref'][0]))
                    wandb.log({"training_samples": text_table})
            if opts.freq_sample > 0 and batches_done % opts.freq_sample == 0:
                img_sample = torch.cat((ret_dict['img']['trg'].data, ret_dict['img']['out'].data), -2)
                save_file = os.path.join(dir_sample, f"train_epoch_{epoch}_batch_{batches_done}.png")
                save_image(img_sample, save_file, nrow=8, normalize=True)
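            # Periodically run the model on the validation split, average the
            # per-term losses over all batches, and log/print the result.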
            if opts.freq_val > 0 and batches_done % opts.freq_val == 0:
                with torch.no_grad():
                    model_main.eval()
                    loss_val = {'img': {'l1': 0.0, 'vggpt': 0.0},
                                'svg': {'total': 0.0, 'cmd': 0.0, 'args': 0.0, 'aux': 0.0},
                                'svg_para': {'total': 0.0, 'cmd': 0.0, 'args': 0.0, 'aux': 0.0}}
                    for val_idx, val_data in enumerate(val_loader):
                        for key in val_data:
                            val_data[key] = val_data[key].cuda()
                        ret_dict_val, loss_dict_val = model_main(val_data, mode='val')
                        for loss_cat in ['img', 'svg']:
                            for key, _ in loss_val[loss_cat].items():
                                loss_val[loss_cat][key] += loss_dict_val[loss_cat][key]
                    for loss_cat in ['img', 'svg']:
                        for key, _ in loss_val[loss_cat].items():
                            loss_val[loss_cat][key] /= len(val_loader)

                    if opts.wandb:
                        for loss_cat in ['img', 'svg']:
                            # Log every averaged validation loss term to wandb
                            for key, value in loss_val[loss_cat].items():
                                wandb.log({f'VAL/loss_{loss_cat}_{key}': value})

                    val_msg = (
                        f"Epoch: {epoch}/{opts.n_epochs}, Batch: {idx}/{len(train_loader)}, "
                        f"Val loss img l1: {loss_val['img']['l1']:.6f}, "
                        f"Val loss img pt: {loss_val['img']['vggpt']:.6f}, "
                        f"Val loss total: {loss_val['svg']['total']:.6f}, "
                        f"Val loss cmd: {loss_val['svg']['cmd']:.6f}, "
                        f"Val loss args: {loss_val['svg']['args']:.6f}"
                    )
                    logfile_val.write(val_msg + "\n")
                    print(val_msg)
                # Switch back to training mode after validation
                model_main.train()

        scheduler.step()
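        # Save a checkpoint every freq_ckpt epochs once past the warm-up threshold,
        # and optionally upload it to wandb as an artifact.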
        if epoch % opts.freq_ckpt == 0 and epoch >= opts.threshold_ckpt:
            ckpt_path = f'{dir_ckpt}/{epoch}_{batches_done}.ckpt'
            if opts.multi_gpu:
                torch.save({'model': model_main.module.state_dict(), 'opt': optimizer.state_dict(),
                            'n_epoch': epoch, 'n_iter': batches_done}, ckpt_path)
            else:
                torch.save({'model': model_main.state_dict(), 'opt': optimizer.state_dict(),
                            'n_epoch': epoch, 'n_iter': batches_done}, ckpt_path)
            print(f"Saved {ckpt_path}")
            if opts.wandb:
                artifact = wandb.Artifact('model_main_checkpoints', type='model')
                artifact.add_file(ckpt_path)
                run.log_artifact(artifact)

    logfile_train.close()
    logfile_val.close()
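# Copy the model and training source files into the experiment folder so the
# exact code used for a run can be recovered later.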
def backup_code(name_exp, exp_path):
    dir_code = os.path.join(exp_path, 'experiments', name_exp, 'code')
    os.makedirs(dir_code, exist_ok=True)
    shutil.copy('models/transformers.py', os.path.join(dir_code, 'transformers.py'))
    shutil.copy('models/model_main.py', os.path.join(dir_code, 'model_main.py'))
    shutil.copy('models/image_encoder.py', os.path.join(dir_code, 'image_encoder.py'))
    shutil.copy('models/image_decoder.py', os.path.join(dir_code, 'image_decoder.py'))
    shutil.copy('./train.py', os.path.join(dir_code, 'train.py'))
    shutil.copy('./options.py', os.path.join(dir_code, 'options.py'))
def train(opts):
    if opts.model_name == 'main_model':
        train_main_model(opts)
    elif opts.model_name == 'others':
        train_others(opts)  # placeholder branch: train_others is not defined in this file
    else:
        raise NotImplementedError
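# Entry point: parse the options, create the experiment directory tree, back up
# the code, dump the options to a text file, and start training.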
def main():
    opts = get_parser_main_model().parse_args()
    opts.name_exp = opts.name_exp + '_' + opts.model_name
    os.makedirs(f"{opts.exp_path}/experiments", exist_ok=True)
    debug = True

    # Create directories
    experiment_dir = os.path.join(f"{opts.exp_path}", "experiments", opts.name_exp)
    os.makedirs(experiment_dir, exist_ok=debug)  # set debug = False to refuse reusing an existing experiment by mistake
    backup_code(opts.name_exp, opts.exp_path)
    os.makedirs(os.path.join(experiment_dir, "samples"), exist_ok=True)
    os.makedirs(os.path.join(experiment_dir, "checkpoints"), exist_ok=True)
    os.makedirs(os.path.join(experiment_dir, "results"), exist_ok=True)
    os.makedirs(os.path.join(experiment_dir, "logs"), exist_ok=True)
    print(f"Training on experiment {opts.name_exp}...")

    # Dump options
    with open(os.path.join(experiment_dir, "opts.txt"), "w") as f:
        for key, value in vars(opts).items():
            f.write(str(key) + ": " + str(value) + "\n")

    train(opts)
if __name__ == "__main__":
    main()
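# Example launch (hypothetical flag spellings -- the real names are defined in
# options.py, which provides the attributes used above such as model_name,
# name_exp, exp_path, n_epochs, batch_size and wandb):
#   python train.py --model_name main_model --name_exp my_run --exp_path ./ --n_epochs 500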