|
import sys |
|
import os |
|
from os.path import join as pjoin |
|
from options.train_options import TrainOptions |
|
from utils.plot_script import * |
|
|
|
from models import build_models |
|
from utils.ema import ExponentialMovingAverage |
|
from trainers import DDPMTrainer |
|
from motion_loader import get_dataset_loader |
|
|
|
from accelerate.utils import set_seed |
|
from accelerate import Accelerator |
|
import torch |
|
|
|
import yaml |
|
from box import Box |
|
|
|
def yaml_to_box(yaml_file):
    """Load a YAML file and return its contents wrapped in a ``Box``.

    Args:
        yaml_file: Path to the YAML file to read.

    Returns:
        A ``Box`` built from the parsed YAML document. An empty document
        yields an empty ``Box`` rather than an error.
    """
    # Explicit encoding so parsing does not depend on the platform's locale
    # default (e.g. cp1252 on Windows) when the config has non-ASCII text.
    with open(yaml_file, 'r', encoding='utf-8') as file:
        # safe_load only builds plain Python objects (no arbitrary tags).
        yaml_data = yaml.safe_load(file)

    # safe_load returns None for an empty document, and a Box cannot be
    # constructed from None; fall back to an empty mapping instead.
    return Box(yaml_data if yaml_data is not None else {})
|
|
|
def _build_ema(opt, model, accelerator):
    """Create the EMA wrapper for ``model``, or return None when EMA is disabled.

    Args:
        opt: Parsed training options; reads ``model_ema``, ``model_ema_steps``,
            ``num_train_steps`` and ``model_ema_decay``.
        model: The network whose weights are averaged.
        accelerator: Used only for process-aware printing.

    Returns:
        An ``ExponentialMovingAverage`` instance, or ``None``.
    """
    if not opt.model_ema:
        return None

    # Rescale the per-update EMA weight (alpha = 1 - decay) by how often EMA
    # updates happen relative to the run length, capped at 1.0.
    # NOTE(review): 106_667 is an unexplained reference constant — presumably
    # the run length the configured decay was tuned for; confirm.
    adjust = 106_667 * opt.model_ema_steps / opt.num_train_steps
    alpha = min(1.0, (1.0 - opt.model_ema_decay) * adjust)

    # accelerator.print (not print) so multi-process runs log this once.
    accelerator.print('EMA alpha:', alpha)
    return ExponentialMovingAverage(model, decay=1.0 - alpha)


def main():
    """Parse options, build the data loader, model and EMA, then train."""
    accelerator = Accelerator()

    parser = TrainOptions()
    opt = parser.parse(accelerator)
    set_seed(opt.seed)

    # Anomaly detection makes autograd raise on NaN gradients and report the
    # forward op that produced them, but PyTorch documents it as a debug-only
    # mode that slows every backward pass.
    # NOTE(review): consider disabling (or gating behind an option) for real runs.
    torch.autograd.set_detect_anomaly(True)

    # Output layout: <checkpoints_dir>/<dataset_name>/<name>/{model,meta}
    opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    opt.model_dir = pjoin(opt.save_root, 'model')
    opt.meta_dir = pjoin(opt.save_root, 'meta')

    # These paths are relative to the current working directory, so the
    # script must be launched from the directory that contains ``options/``.
    config_path = 'options/edit.yaml' if opt.edit_mode else 'options/noedit.yaml'
    edit_config = yaml_to_box(config_path)

    # Only the main process creates the output directories.
    if accelerator.is_main_process:
        os.makedirs(opt.model_dir, exist_ok=True)
        os.makedirs(opt.meta_dir, exist_ok=True)

    train_datasetloader = get_dataset_loader(opt, batch_size=opt.batch_size, split='train', accelerator=accelerator, mode='train')

    accelerator.print('\nInitializing model ...')
    encoder = build_models(opt, edit_config=edit_config)
    model_ema = _build_ema(opt, encoder, accelerator)
    accelerator.print('Finish building Model.\n')

    trainer = DDPMTrainer(opt, encoder, accelerator, model_ema)
    trainer.train(train_datasetloader)


if __name__ == '__main__':
    main()
|
|
|
|
|
|