import os
import shutil
import argparse
import random
import numpy as np
from datetime import datetime
from tqdm import tqdm
import importlib
import copy
import librosa
from pathlib import Path
import json
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
import wandb
from diffusers.optimization import get_scheduler
from omegaconf import OmegaConf
from emage_evaltools.mertic import FGD, BC, L1div, LVDFace, MSEFace
from emage_utils.motion_io import beat_format_load, beat_format_save, MASK_DICT, recover_from_mask
import emage_utils.rotation_conversions as rc
from emage_utils import fast_render
from emage_utils.motion_rep_transfer import get_motion_rep_numpy
from models.emage_audio import EmageVQVAEConv, EmageVAEConv, EmageVQModel, EmageAudioModel


# --------------------------------- train,val,test fn here --------------------------------- #
def inference_fn(cfg, model, device, test_path, save_path, **kwargs):
    """Generate motion for every unique test clip and save it in BEAT npz format.

    Args:
        cfg: model-level config. Reads ``audio_sr`` and ``pose_fps``, plus the
            weights ``lf/lu/lh/ll`` and ``cf/cu/ch/cl``. Those weights decide,
            per body part, whether to decode from regressed latents or from
            classified codebook indices.
        model: the audio-to-motion model, optionally wrapped in DDP.
        device: torch device used for inference.
        test_path: iterable of JSON metadata files. Only entries with
            ``"mode" == "test"`` are used, de-duplicated by ``video_id``
            (first occurrence wins).
        save_path: directory that receives ``<video_id>_output.npz`` files.
        **kwargs: must contain ``motion_vq`` (the frozen VQ model used to decode).

    Returns:
        ``(test_list, save_list)``: the filtered metadata entries, and for each
        one a dict with ``audio_path``, ``motion_path`` (the generated file) and
        ``video_id``.
    """
    motion_vq = kwargs["motion_vq"]
    actual_model = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
    actual_model.eval()

    raw_items = []
    for data_meta_path in test_path:
        with open(data_meta_path, "r") as f:
            raw_items.extend(json.load(f))
    # Keep only test-split entries, first occurrence per video_id.
    seen_ids = set()
    test_list = []
    for item in raw_items:
        if item.get("mode") != "test" or item["video_id"] in seen_ids:
            continue
        seen_ids.add(item["video_id"])
        test_list.append(item)

    save_list = []
    start_time = time.time()
    total_length = 0
    for test_file in tqdm(test_list, desc="Testing"):
        audio, _ = librosa.load(test_file["audio_path"], sr=cfg.audio_sr)
        audio = torch.from_numpy(audio).to(device).unsqueeze(0)
        speaker_id = torch.zeros(1, 1).to(device).long()

        # motion seed
        motion_data = np.load(test_file["motion_path"], allow_pickle=True)
        poses = torch.from_numpy(motion_data["poses"]).unsqueeze(0).to(device).float()
        # Foot contact lives in a parallel directory with a .npy extension.
        foot_contact = torch.from_numpy(
            np.load(test_file["motion_path"].replace("smplxflame_30", "footcontact").replace(".npz", ".npy"))
        ).unsqueeze(0).to(device).float()
        trans = torch.from_numpy(motion_data["trans"]).unsqueeze(0).to(device).float()
        # Only used by the (commented) reconstruction check below.
        expression = torch.from_numpy(motion_data["expressions"]).unsqueeze(0).to(device).float()
        bs, t, _ = poses.shape
        poses_6d = rc.axis_angle_to_rotation_6d(poses.reshape(bs, t, -1, 3)).reshape(bs, t, -1)
        masked_motion = torch.cat([poses_6d, trans, foot_contact], dim=-1)  # bs t 337

        # reconstruction check (debug aid): replace the model output with
        # motion_vq.map2latent(poses_6d, expression, tar_contact=foot_contact, tar_trans=trans)
        latent_dict = actual_model.inference(audio, speaker_id, motion_vq, masked_motion=masked_motion)
        # A part decodes from regressed latents only when it has a reconstruction
        # weight and no classification weight; otherwise it uses argmax indices.
        face_latent = latent_dict["rec_face"] if cfg.lf > 0 and cfg.cf == 0 else None
        upper_latent = latent_dict["rec_upper"] if cfg.lu > 0 and cfg.cu == 0 else None
        hands_latent = latent_dict["rec_hands"] if cfg.lh > 0 and cfg.ch == 0 else None
        lower_latent = latent_dict["rec_lower"] if cfg.ll > 0 and cfg.cl == 0 else None
        face_index = torch.max(F.log_softmax(latent_dict["cls_face"], dim=2), dim=2)[1] if cfg.cf > 0 else None
        upper_index = torch.max(F.log_softmax(latent_dict["cls_upper"], dim=2), dim=2)[1] if cfg.cu > 0 else None
        hands_index = torch.max(F.log_softmax(latent_dict["cls_hands"], dim=2), dim=2)[1] if cfg.ch > 0 else None
        lower_index = torch.max(F.log_softmax(latent_dict["cls_lower"], dim=2), dim=2)[1] if cfg.cl > 0 else None
        motion_all = motion_vq.decode(
            face_latent=face_latent, upper_latent=upper_latent, lower_latent=lower_latent, hands_latent=hands_latent,
            face_index=face_index, upper_index=upper_index, lower_index=lower_index, hands_index=hands_index,
            get_global_motion=True, ref_trans=trans[:, 0])

        motion_pred = motion_all["motion_axis_angle"]
        t = motion_pred.shape[1]
        # reshape(t, -1) assumes batch size 1 (one clip at a time).
        motion_pred = motion_pred.cpu().numpy().reshape(t, -1)
        expression_pred = motion_all["expression"].cpu().numpy().reshape(t, -1)
        trans_pred = motion_all["trans"].cpu().numpy().reshape(t, -1)
        out_path = os.path.join(save_path, f"{test_file['video_id']}_output.npz")
        beat_format_save(out_path, motion_pred, upsample=30 // cfg.pose_fps, expressions=expression_pred, trans=trans_pred)
        save_list.append(
            {
                "audio_path": test_file["audio_path"],
                "motion_path": out_path,
                "video_id": test_file["video_id"],
            }
        )
        total_length += t
    time_cost = time.time() - start_time
    print(f"\n cost {time_cost:.2f} seconds to generate {total_length / cfg.pose_fps:.2f} seconds of motion")
    return test_list, save_list


def get_mask(mask, ratio):
    """Placeholder for a masking helper; not implemented (returns None)."""
    pass


def get_rec_loss(motion_pred, motion_gt, lu, ll, lh, lf):
    """Weighted sum of per-part MSE between predicted (``rec_*``) and target latents."""
    rec_loss_upper = lu * F.mse_loss(motion_pred["rec_upper"], motion_gt["upper"])
    rec_loss_lower = ll * F.mse_loss(motion_pred["rec_lower"], motion_gt["lower"])
    rec_loss_hands = lh * F.mse_loss(motion_pred["rec_hands"], motion_gt["hands"])
    rec_loss_face = lf * F.mse_loss(motion_pred["rec_face"], motion_gt["face"])
    return rec_loss_upper + rec_loss_lower + rec_loss_hands + rec_loss_face


def get_cls_loss(motion_pred, motion_gt, cu, cl, ch, cf, ClsFn):
    """Weighted sum of per-part classification losses over codebook indices.

    ``ClsFn`` is applied to log-softmax scores permuted to (batch, classes, time),
    which is the layout ``nn.NLLLoss`` expects.
    """
    ClsFn = ClsFn.to(motion_pred["cls_upper"].device)
    pred_upper = F.log_softmax(motion_pred["cls_upper"], dim=2).permute(0, 2, 1)
    pred_lower = F.log_softmax(motion_pred["cls_lower"], dim=2).permute(0, 2, 1)
    pred_hands = F.log_softmax(motion_pred["cls_hands"], dim=2).permute(0, 2, 1)
    pred_face = F.log_softmax(motion_pred["cls_face"], dim=2).permute(0, 2, 1)
    cls_loss_upper = cu * ClsFn(pred_upper, motion_gt["upper"])
    cls_loss_lower = cl * ClsFn(pred_lower, motion_gt["lower"])
    cls_loss_hands = ch * ClsFn(pred_hands, motion_gt["hands"])
    cls_loss_face = cf * ClsFn(pred_face, motion_gt["face"])
    return cls_loss_upper + cls_loss_lower + cls_loss_hands + cls_loss_face


def train_val_fn(cfg, batch, model, device, mode="train", **kwargs):
    """Run one train or validation step and return the loss dict.

    Three forward passes are made:
    (1) audio plus seed frames,
    (2) audio plus a random mask,
    (3) random mask without audio.
    In "train" mode the summed loss is back-propagated and
    ``optimizer``/``lr_scheduler`` are stepped. In "val" mode the seed-frame
    prediction is decoded and fed to ``fgd_evaluator``.

    Required kwargs: ``motion_vq``, ``ClsFn``, ``iteration``. Train mode also
    needs ``optimizer`` and ``lr_scheduler``; val mode needs ``fgd_evaluator``.
    """
    if mode == "train":
        model.train()
        kwargs["optimizer"].zero_grad()
    else:
        model.eval()
    motion_vq = kwargs["motion_vq"]
    motion_gt = batch["motion"].to(device)
    audio = batch["audio"].to(device)
    expressions_gt = batch["expressions"].to(device)
    trans = batch["trans"].to(device)
    foot_contact = batch["foot_contact"].to(device)
    bs, t, jc = motion_gt.shape
    j = jc // 3
    speaker_id = torch.zeros(bs, 1).to(device).long()
    motion_gt = rc.axis_angle_to_rotation_6d(motion_gt.reshape(bs, t, j, 3)).reshape(bs, t, j * 6)
    latent_index_dict = motion_vq.map2index(motion_gt, expressions_gt, tar_contact=foot_contact, tar_trans=trans)
    latent_dict = motion_vq.map2latent(motion_gt, expressions_gt, tar_contact=foot_contact, tar_trans=trans)
    masked_motion = torch.cat([motion_gt, trans, foot_contact], dim=-1)

    # forward use audio: the first `seed_frames` frames are visible (mask 0), the rest masked (1)
    mask = torch.ones_like(masked_motion).to(device)
    mask[:, :cfg.model.seed_frames] = 0
    motion_pred = model(audio, speaker_id, masked_motion=masked_motion, mask=mask, use_audio=True)
    loss_dict = {
        "rec_seed": get_rec_loss(motion_pred, latent_dict, cfg.model.lu, cfg.model.ll, cfg.model.lh, cfg.model.lf),
        "cls_seed": get_cls_loss(motion_pred, latent_index_dict, cfg.model.cu, cfg.model.cl, cfg.model.ch, cfg.model.cf, kwargs["ClsFn"]),
    }

    # forward use random mask and audio.
    # Curriculum: the masked fraction ramps linearly from 0.05 to 1.0 over
    # 135 * 400 iterations. The original `iteration/135*400` parsed as
    # iteration*400/135, which saturated the ratio after a single step.
    mask_warmup_iters = 135 * 400
    mask_ratio = min(kwargs["iteration"] / mask_warmup_iters, 1.0) * 0.95 + 0.05
    mask = torch.rand(bs, t, cfg.model.pose_dims + 3 + 4) < mask_ratio
    mask = mask.float().to(device)
    motion_pred_random_audio = model(audio, speaker_id, masked_motion=masked_motion, mask=mask, use_audio=True)
    loss_dict["rec_audio"] = get_rec_loss(motion_pred_random_audio, latent_dict, cfg.model.lu, cfg.model.ll, cfg.model.lh, cfg.model.lf)
    loss_dict["cls_audio"] = get_cls_loss(motion_pred_random_audio, latent_index_dict, cfg.model.cu, cfg.model.cl, cfg.model.ch, cfg.model.cf, kwargs["ClsFn"])

    # forward use random mask (same mask, no audio)
    motion_pred_random_mask = model(audio, speaker_id, masked_motion=masked_motion, mask=mask, use_audio=False)
    loss_dict["rec_mask"] = get_rec_loss(motion_pred_random_mask, latent_dict, cfg.model.lu, cfg.model.ll, cfg.model.lh, cfg.model.lf)
    loss_dict["cls_mask"] = get_cls_loss(motion_pred_random_mask, latent_index_dict, cfg.model.cu, cfg.model.cl, cfg.model.ch, cfg.model.cf, kwargs["ClsFn"])

    all_loss = sum(loss_dict.values())
    loss_dict["all"] = all_loss

    if mode == "train":
        all_loss.backward()
        # Clip AFTER backward so the freshly computed gradients are clipped.
        if cfg.solver.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.solver.max_grad_norm)
        kwargs["optimizer"].step()
        kwargs["lr_scheduler"].step()

    if mode == "val":
        _, cls_face = torch.max(F.log_softmax(motion_pred["cls_face"], dim=2), dim=2)
        _, cls_upper = torch.max(F.log_softmax(motion_pred["cls_upper"], dim=2), dim=2)
        _, cls_hands = torch.max(F.log_softmax(motion_pred["cls_hands"], dim=2), dim=2)
        _, cls_lower = torch.max(F.log_softmax(motion_pred["cls_lower"], dim=2), dim=2)
        face_latent = motion_pred["rec_face"] if cfg.model.lf > 0 and cfg.model.cf == 0 else None
        upper_latent = motion_pred["rec_upper"] if cfg.model.lu > 0 and cfg.model.cu == 0 else None
        hands_latent = motion_pred["rec_hands"] if cfg.model.lh > 0 and cfg.model.ch == 0 else None
        lower_latent = motion_pred["rec_lower"] if cfg.model.ll > 0 and cfg.model.cl == 0 else None
        face_index = cls_face if cfg.model.cf > 0 else None
        upper_index = cls_upper if cfg.model.cu > 0 else None
        hands_index = cls_hands if cfg.model.ch > 0 else None
        lower_index = cls_lower if cfg.model.cl > 0 else None
        decode_dict = motion_vq.decode(
            face_latent=face_latent, upper_latent=upper_latent, lower_latent=lower_latent, hands_latent=hands_latent,
            face_index=face_index, upper_index=upper_index, lower_index=lower_index, hands_index=hands_index,)
        # Drop the trailing 7 channels (3 trans + 4 foot contact) to keep rot6d poses only.
        motion_pred_rot6d = decode_dict["all_motion4inference"][:, :, :-7]
        # cache feature for evaluation
        kwargs["fgd_evaluator"].update(motion_pred_rot6d, motion_gt)
    return loss_dict


# --------------------------------- main train loop here --------------------------------- #
def main(cfg):
    """Distributed (NCCL) training entry point with periodic test and validation.

    When ``cfg.test`` is set, only one test/evaluation pass runs, then the
    process returns 0.
    """
    seed_everything(cfg.seed)
    os.environ["WANDB_API_KEY"] = cfg.wandb_key
    local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(backend="nccl")
    log_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    experiment_ckpt_dir = os.path.join(log_dir, "checkpoints")
    os.makedirs(experiment_ckpt_dir, exist_ok=True)
    if local_rank == 0 and cfg.validation.wandb:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        wandb.init(
            project=cfg.wandb_project,
            name=f"{cfg.exp_name}_{run_time}",
            entity=cfg.wandb_entity,
            dir=log_dir,
            config=OmegaConf.to_container(cfg)
        )

    # init: frozen pretrained per-part VQ models + global motion autoencoder
    face_motion_vq = EmageVQVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/face").to(device)
    upper_motion_vq = EmageVQVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/upper").to(device)
    lower_motion_vq = EmageVQVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/lower").to(device)
    hands_motion_vq = EmageVQVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/hands").to(device)
    global_motion_ae = EmageVAEConv.from_pretrained("H-Liu1997/emage_audio", subfolder="emage_vq/global").to(device)
    motion_vq = EmageVQModel(
        face_model=face_motion_vq, upper_model=upper_motion_vq, lower_model=lower_motion_vq,
        hands_model=hands_motion_vq, global_model=global_motion_ae).to(device)
    for param in motion_vq.parameters():
        param.requires_grad = False
    motion_vq.eval()

    if cfg.test:
        # Path can be overridden via `test_ckpt=...`; default keeps the previous location.
        test_ckpt = cfg.get("test_ckpt", None) or "/content/drive/MyDrive/weights/emage3/best"
        model = EmageAudioModel.from_pretrained(test_ckpt).to(device)
    else:
        model = init_hf_class(cfg.model.name_pyfile, cfg.model.class_name, cfg.model).to(device)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    for param in model.parameters():
        param.requires_grad = True
    model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True, broadcast_buffers=False)

    # optimizer
    optimizer_cls = torch.optim.Adam
    optimizer = optimizer_cls(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=cfg.solver.learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon
    )
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
    )

    # loss: expects log-probabilities (see get_cls_loss)
    ClsFn = nn.NLLLoss()

    # dataset
    train_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split='train')
    test_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split='test')
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    train_loader = DataLoader(train_dataset, batch_size=cfg.data.train_bs, sampler=train_sampler, drop_last=True, num_workers=8)
    test_loader = DataLoader(test_dataset, batch_size=cfg.data.train_bs, sampler=test_sampler, drop_last=False, num_workers=8)

    # resume
    if cfg.resume_from_checkpoint:
        checkpoint = torch.load(cfg.resume_from_checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        iteration = checkpoint["iteration"]
    else:
        iteration = 0
    if cfg.test:
        iteration = 0
    max_epochs = (cfg.solver.max_train_steps // len(train_loader)) + (1 if cfg.solver.max_train_steps % len(train_loader) != 0 else 0)
    start_epoch = iteration // len(train_loader)
    start_step_in_epoch = iteration % len(train_loader)

    fgd_evaluator = FGD(download_path="./emage_evaltools/")
    bc_evaluator = BC(download_path="./emage_evaltools/", sigma=0.3, order=7)
    l1div_evaluator = L1div()
    lvd_evaluator = LVDFace()
    mse_evaluator = MSEFace()

    loss_meters = {}
    loss_meters_val = {}
    best_fgd_val = np.inf
    best_fgd_iteration_val = 0
    best_fgd_test = np.inf
    best_fgd_iteration_test = 0

    # train loop
    epoch = start_epoch
    while iteration < cfg.solver.max_train_steps:
        train_sampler.set_epoch(epoch)
        data_start = time.time()
        pbar = tqdm(train_loader, leave=True)
        for i, batch in enumerate(pbar):
            # For a correct resume: the sampler order is fixed by seed + epoch, so
            # skip the batches already consumed in this epoch. `iteration` was
            # restored from the checkpoint and already counts them; do not add them again.
            if i < start_step_in_epoch:
                continue
            # Stop exactly at max_train_steps instead of finishing the epoch.
            if iteration >= cfg.solver.max_train_steps:
                break

            # test
            if iteration % cfg.validation.test_steps == 0 and local_rank == 0:
                test_save_path = os.path.join(log_dir, f"test_{iteration}")
                os.makedirs(test_save_path, exist_ok=True)
                with torch.no_grad():
                    test_list, save_list = inference_fn(cfg.model, model, device, cfg.data.test_meta_paths, test_save_path, motion_vq=motion_vq)
                    if cfg.validation.evaluation:
                        metrics = evaluation_fn([True] * 55, test_list, save_list, fgd_evaluator, bc_evaluator, l1div_evaluator, device, lvd_evaluator, mse_evaluator)
                if cfg.validation.visualization:
                    visualization_fn(save_list, test_save_path, test_list, only_check_one=True)
                if cfg.validation.evaluation:
                    best_fgd_test, best_fgd_iteration_test = log_test(model, metrics, iteration, best_fgd_test, best_fgd_iteration_test, cfg, local_rank, experiment_ckpt_dir, test_save_path)
                if cfg.test:
                    # Release resources that the normal exit path at the end of main() would release.
                    if cfg.validation.wandb:
                        wandb.finish()
                    torch.distributed.destroy_process_group()
                    return 0

            # validation
            if iteration % cfg.validation.validation_steps == 0:
                loss_meters = {}
                loss_meters_val = {}
                fgd_evaluator.reset()
                pbar_val = tqdm(test_loader, leave=True)
                data_start_val = time.time()
                # Distinct name so the training `batch` used below is not overwritten.
                for j, val_batch in enumerate(pbar_val):
                    data_time_val = time.time() - data_start_val
                    with torch.no_grad():
                        val_loss_dict = train_val_fn(cfg, val_batch, model, device, mode="val", fgd_evaluator=fgd_evaluator, motion_vq=motion_vq, ClsFn=ClsFn, iteration=iteration)
                    net_time_val = time.time() - data_start_val
                    log_train_val(cfg, val_loss_dict, local_rank, loss_meters_val, pbar_val, epoch, max_epochs, iteration, net_time_val, data_time_val, optimizer, "Val ")
                    data_start_val = time.time()
                    if cfg.debug and j > 1:
                        break
                # FGD over every validation batch this rank processed (also valid after a debug early break).
                val_fgd = fgd_evaluator.compute()
                if local_rank == 0:
                    best_fgd_val, best_fgd_iteration_val = save_last_and_best_ckpt(
                        model, optimizer, lr_scheduler, iteration, experiment_ckpt_dir, best_fgd_val, best_fgd_iteration_val,
                        val_fgd, lower_is_better=True, mertic_name="fgd")

            # train
            data_time = time.time() - data_start
            loss_dict = train_val_fn(cfg, batch, model, device, mode="train", motion_vq=motion_vq, optimizer=optimizer, lr_scheduler=lr_scheduler, ClsFn=ClsFn, iteration=iteration)
            net_time = time.time() - data_start - data_time
            log_train_val(cfg, loss_dict, local_rank, loss_meters, pbar, epoch, max_epochs, iteration, net_time, data_time, optimizer, "Train")
            data_start = time.time()
            iteration += 1
        start_step_in_epoch = 0
        epoch += 1

    if local_rank == 0 and cfg.validation.wandb:
        wandb.finish()
    torch.distributed.destroy_process_group()


# --------------------------------- utils fn here --------------------------------- #
def evaluation_fn(joint_mask, gt_list, pred_list, fgd_evaluator, bc_evaluator, l1_evaluator, device, lvd_evaluator, mse_evaluator):
    """Compute FGD, BC, L1div, LVD and face MSE over paired GT/prediction files.

    Pairs are matched by ``video_id``. GT clips without a prediction are
    reported and skipped. Both sequences are truncated to the shorter length.
    Returns a dict with keys ``fgd``, ``bc``, ``l1``, ``lvd``, ``mse``.
    """
    fgd_evaluator.reset()
    bc_evaluator.reset()
    l1_evaluator.reset()
    lvd_evaluator.reset()
    mse_evaluator.reset()

    for test_file in tqdm(gt_list, desc="Evaluation"):
        pred_file = next((item for item in pred_list if item["video_id"] == test_file["video_id"]), None)
        if pred_file is None:
            print(f"Missing prediction for {test_file['video_id']}")
            continue
        # only load selective joints
        gt_dict = beat_format_load(test_file["motion_path"], joint_mask)
        pred_dict = beat_format_load(pred_file["motion_path"], joint_mask)
        motion_gt = gt_dict["poses"]
        motion_pred = pred_dict["poses"]
        expressions_gt = gt_dict["expressions"]
        expressions_pred = pred_dict["expressions"]
        betas = gt_dict["betas"]
        t = min(motion_gt.shape[0], motion_pred.shape[0])
        motion_gt = motion_gt[:t]
        motion_pred = motion_pred[:t]
        expressions_gt = expressions_gt[:t]
        expressions_pred = expressions_pred[:t]

        # bc and l1 require position representation
        motion_position_pred = get_motion_rep_numpy(motion_pred, device=device, betas=betas)["position"]  # t*55*3
        motion_position_pred = motion_position_pred.reshape(t, -1)
        # ignore the start and end 2s (60 frames at 30 fps, 16 kHz audio); this may be BEAT-specific
        audio_beat = bc_evaluator.load_audio(test_file["audio_path"], t_start=2 * 16000, t_end=int((t - 60) / 30 * 16000))
        motion_beat = bc_evaluator.load_motion(motion_position_pred, t_start=60, t_end=t - 60, pose_fps=30, without_file=True)
        bc_evaluator.compute(audio_beat, motion_beat, length=t - 120, pose_fps=30)
        l1_evaluator.compute(motion_position_pred)

        face_position_pred = get_motion_rep_numpy(motion_pred, device=device, expressions=expressions_pred, expression_only=True, betas=betas)["vertices"]  # t -1
        face_position_gt = get_motion_rep_numpy(motion_gt, device=device, expressions=expressions_gt, expression_only=True, betas=betas)["vertices"]
        lvd_evaluator.compute(face_position_pred, face_position_gt)
        mse_evaluator.compute(face_position_pred, face_position_gt)

        # fgd requires rotation 6d representation
        motion_gt = torch.from_numpy(motion_gt).to(device).unsqueeze(0)
        motion_pred = torch.from_numpy(motion_pred).to(device).unsqueeze(0)
        motion_gt = rc.axis_angle_to_rotation_6d(motion_gt.reshape(1, t, 55, 3)).reshape(1, t, 55 * 6)
        motion_pred = rc.axis_angle_to_rotation_6d(motion_pred.reshape(1, t, 55, 3)).reshape(1, t, 55 * 6)
        fgd_evaluator.update(motion_pred.float(), motion_gt.float())

    metrics = {}
    metrics["fgd"] = fgd_evaluator.compute()
    metrics["bc"] = bc_evaluator.avg()
    metrics["l1"] = l1_evaluator.avg()
    metrics["lvd"] = lvd_evaluator.avg()
    metrics["mse"] = mse_evaluator.avg()
    return metrics


def visualization_fn(pred_list, save_path, gt_list=None, only_check_one=True):
    """Render predicted sequences to ``save_path``.

    Without ``gt_list``, each prediction is rendered alone. With ``gt_list``,
    each prediction is truncated to its GT length, written as
    ``<video_id>_transpad.npz``, and rendered side by side with the GT.
    When ``only_check_one`` is set, rendering stops after the first sequence.
    """
    if gt_list is None:
        # single visualization
        for pred in pred_list:
            fast_render.render_one_sequence(
                pred["motion_path"],
                save_path,
                pred["audio_path"],
                model_folder="./evaluation/smplx_models/",
            )
            if only_check_one:
                break
        return

    # paired visualization, pad the translation
    for pred in pred_list:
        npz_pred = np.load(pred["motion_path"], allow_pickle=True)
        gt_file = next((item for item in gt_list if item["video_id"] == pred["video_id"]), None)
        if gt_file is None:
            print(f"Missing ground truth for {pred['video_id']}")
            continue
        npz_gt = np.load(gt_file["motion_path"], allow_pickle=True)
        t = npz_gt["poses"].shape[0]
        # NOTE(review): betas is sliced like a per-frame array; if it is a
        # per-subject vector, [:t] truncates it when t is small — confirm.
        transpad_path = os.path.join(save_path, f"{pred['video_id']}_transpad.npz")
        np.savez(
            transpad_path,
            betas=npz_pred['betas'][:t],
            poses=npz_pred['poses'][:t],
            expressions=npz_pred['expressions'][:t],
            trans=npz_pred["trans"][:t],
            model='smplx2020',
            gender='neutral',
            mocap_frame_rate=30,
        )
        fast_render.render_one_sequence(
            transpad_path,
            gt_file["motion_path"],
            save_path,
            pred["audio_path"],
            model_folder="./evaluation/smplx_models/",
        )
        if only_check_one:
            break


def log_test(model, metrics, iteration, best_mertics, best_iteration, cfg, local_rank, experiment_ckpt_dir, video_save_path=None):
    """Print and log test metrics and save ``test_best`` when FGD improves.

    Only rank 0 does any work. Returns the (possibly updated) best FGD and its iteration.
    """
    if local_rank == 0:
        print(f"\n Test Results at iteration {iteration}:")
        for key, value in metrics.items():
            print(f"  {key}: {value:.10f}")
        if cfg.validation.wandb:
            for key, value in metrics.items():
                wandb.log({f"test/{key}": value}, step=iteration)
        if cfg.validation.wandb and cfg.validation.visualization:
            videos_to_log = [
                wandb.Video(os.path.join(video_save_path, filename))
                for filename in os.listdir(video_save_path)
                if filename.endswith(".mp4")
            ]
            if videos_to_log:
                wandb.log({"test/videos": videos_to_log}, step=iteration)
        if metrics["fgd"] < best_mertics:
            best_mertics = metrics["fgd"]
            best_iteration = iteration
            model.module.save_pretrained(os.path.join(experiment_ckpt_dir, "test_best"))
        message = f"Current Test FGD: {metrics['fgd']:.4f} (Best: {best_mertics:.4f} at iteration {best_iteration})"
        log_metric_with_box(message)
    return best_mertics, best_iteration


def log_metric_with_box(message):
    """Print ``message`` framed by a dashed box."""
    box_width = len(message) + 2
    border = "-" * box_width
    print(f"\n{border}")
    print(f"|{message}|")
    print(f"{border}\n")


def log_train_val(cfg, loss_dict, local_rank, loss_meters, pbar, epoch, max_epochs, iteration, net_time, data_time, optimizer, ptype="Train"):
    """Update running loss averages in ``loss_meters`` (in place), refresh the
    progress bar, and optionally log to wandb. Keys containing "fgd" are ignored."""
    new_loss_dict = {}
    for k, v in loss_dict.items():
        if "fgd" in k:
            continue
        v_cpu = torch.as_tensor(v).float().cpu().item()
        meter = loss_meters.setdefault(k, {"sum": 0, "count": 0})
        meter["sum"] += v_cpu
        meter["count"] += 1
        new_loss_dict[k] = v_cpu
    mem_used = torch.cuda.memory_reserved() / 1E9
    lr = optimizer.param_groups[0]["lr"]
    loss_str = " ".join([f"{k}: {new_loss_dict[k]:.4f}({loss_meters[k]['sum']/loss_meters[k]['count']:.4f})" for k in new_loss_dict])
    desc = f"{ptype}: Epoch[{epoch}/{max_epochs}] Iter[{iteration}] {loss_str} lr: {lr:.2E} data_time: {data_time:.3f} net_time: {net_time:.3f} mem: {mem_used:.2f}GB"
    pbar.set_description(desc)
    pbar.bar_format = "{desc} {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]"
    if cfg.validation.wandb and local_rank == 0 and new_loss_dict:
        # One call per step instead of one per key.
        wandb.log({f"loss/{ptype}/{k}": v for k, v in new_loss_dict.items()}, step=iteration)


def save_last_and_best_ckpt(model, optimizer, lr_scheduler, iteration, save_dir, previous_best, best_iteration, current, lower_is_better=True, mertic_name="fgd"):
    """Save the ``last`` checkpoint, and also copy it to ``best`` when ``current`` improves.

    Returns the (possibly updated) best value and its iteration.
    """
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        "iteration": iteration,
    }
    torch.save(checkpoint, os.path.join(save_dir, "last.bin"))
    model.module.save_pretrained(os.path.join(save_dir, "last"))
    improved = current < previous_best if lower_is_better else current > previous_best
    if improved:
        previous_best = current
        best_iteration = iteration
        shutil.copy(os.path.join(save_dir, "last.bin"), os.path.join(save_dir, "best.bin"))
        model.module.save_pretrained(os.path.join(save_dir, "best"))
    message = f"Current interation {iteration} {mertic_name}: {current:.4f} (Best: {previous_best:.4f} at iteration {best_iteration})"
    log_metric_with_box(message)
    return previous_best, best_iteration


def init_hf_class(module_name, class_name, config, **kwargs):
    """Import ``module_name.class_name`` and build it from its ``config_class(config_obj=config)``."""
    module = importlib.import_module(module_name)
    model_class = getattr(module, class_name)
    config_class = model_class.config_class
    config = config_class(config_obj=config)
    instance = model_class(config, **kwargs)
    return instance


def init_class(module_name, class_name, config, **kwargs):
    """Import ``module_name.class_name`` and instantiate it as ``cls(config, **kwargs)``."""
    module = importlib.import_module(module_name)
    model_class = getattr(module, class_name)
    instance = model_class(config, **kwargs)
    return instance


def seed_everything(seed):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) and configure cuDNN."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): benchmark=True lets cuDNN auto-select algorithms, which can
    # undermine the determinism requested above. Set False for strict reproducibility.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True


def init_env():
    """Parse CLI arguments, build the OmegaConf config, and snapshot the run.

    The snapshot (resolved config + all ``.py`` files under the working
    directory) is written to ``<output_dir>/<exp_name>/sanity_check``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/train/stage2.yaml")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--visualization", action="store_true")
    parser.add_argument("--evaluation", action="store_true")
    parser.add_argument("--test", action="store_true")
    parser.add_argument('overrides', nargs=argparse.REMAINDER)
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    config.exp_name = os.path.splitext(os.path.basename(args.config))[0]
    if args.overrides:
        config = OmegaConf.merge(config, OmegaConf.from_dotlist(args.overrides))
    if args.debug:
        config.wandb_project = "debug"
        config.exp_name = "debug"
        config.solver.max_train_steps = 4
    else:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        config.exp_name = config.exp_name + "_" + run_time
    if args.wandb:
        config.validation.wandb = True
    if args.visualization:
        config.validation.visualization = True
    if args.evaluation:
        config.validation.evaluation = True
    if args.test:
        config.test = True

    save_dir = os.path.join(config.output_dir, config.exp_name)
    os.makedirs(save_dir, exist_ok=True)
    sanity_check_dir = os.path.join(save_dir, 'sanity_check')
    os.makedirs(sanity_check_dir, exist_ok=True)
    with open(os.path.join(sanity_check_dir, f'{config.exp_name}.yaml'), 'w') as f:
        OmegaConf.save(config, f)

    current_dir = Path.cwd()
    for py_file in current_dir.rglob('*.py'):
        rel_path = py_file.relative_to(current_dir)
        # Skip earlier runs' snapshots; otherwise they are re-copied into every new snapshot and nest.
        if 'sanity_check' in rel_path.parts:
            continue
        dest_path = Path(sanity_check_dir) / rel_path
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(py_file, dest_path)
    return config


if __name__ == "__main__":
    config = init_env()
    main(config)