# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import gc
import json
import logging
import math
import os
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Mapping, Optional
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from training.optimizer import construct_optimizer
from training.utils.checkpoint_utils import (
assert_skipped_parameters_are_frozen,
exclude_params_matching_unix_pattern,
load_state_dict_into_model,
with_check_parameter_frozen,
)
from training.utils.data_utils import BatchedVideoDatapoint
from training.utils.distributed import all_reduce_max, barrier, get_rank
from training.utils.logger import Logger, setup_logging
from training.utils.train_utils import (
AverageMeter,
collect_dict_keys,
DurationMeter,
get_amp_type,
get_machine_local_and_dist_rank,
get_resume_checkpoint,
human_readable_time,
is_dist_avail_and_initialized,
log_env_variables,
makedir,
MemMeter,
Phase,
ProgressMeter,
set_seeds,
setup_distributed_backend,
)
CORE_LOSS_KEY = "core_loss"
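# Checkpoint saving and the validation epoch hooks below need the raw module,
# so unwrap the DistributedDataParallel wrapper when present.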
def unwrap_ddp_if_wrapped(model):
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
return model.module
return model
@dataclass
class OptimAMPConf:
enabled: bool = False
amp_dtype: str = "float16"
@dataclass
class OptimConf:
    optimizer: Optional[torch.optim.Optimizer] = None
options: Optional[Dict[str, Any]] = None
param_group_modifiers: Optional[List] = None
amp: Optional[Dict[str, Any]] = None
gradient_clip: Any = None
gradient_logger: Any = None
def __post_init__(self):
# amp
if not isinstance(self.amp, OptimAMPConf):
if self.amp is None:
self.amp = {}
assert isinstance(self.amp, Mapping)
self.amp = OptimAMPConf(**self.amp)
@dataclass
class DistributedConf:
backend: Optional[str] = None # inferred from accelerator type
comms_dtype: Optional[str] = None
find_unused_parameters: bool = False
timeout_mins: int = 30
@dataclass
class CudaConf:
cudnn_deterministic: bool = False
cudnn_benchmark: bool = True
allow_tf32: bool = False
# if not None, `matmul_allow_tf32` key will override `allow_tf32` for matmul
matmul_allow_tf32: Optional[bool] = None
# if not None, `cudnn_allow_tf32` key will override `allow_tf32` for cudnn
cudnn_allow_tf32: Optional[bool] = None
@dataclass
class CheckpointConf:
save_dir: str
save_freq: int
save_list: List[int] = field(default_factory=list)
model_weight_initializer: Any = None
    save_best_meters: Optional[List[str]] = None
skip_saving_parameters: List[str] = field(default_factory=list)
initialize_after_preemption: Optional[bool] = None
# if not None, training will be resumed from this checkpoint
resume_from: Optional[str] = None
def infer_missing(self):
if self.initialize_after_preemption is None:
with_skip_saving = len(self.skip_saving_parameters) > 0
self.initialize_after_preemption = with_skip_saving
return self
@dataclass
class LoggingConf:
log_dir: str
log_freq: int # In iterations
tensorboard_writer: Any
log_level_primary: str = "INFO"
log_level_secondary: str = "ERROR"
log_scalar_frequency: int = 100
log_visual_frequency: int = 100
scalar_keys_to_log: Optional[Dict[str, Any]] = None
log_batch_stats: bool = False
class Trainer:
"""
    Trainer supporting DDP-based distributed training.
"""
EPSILON = 1e-8
def __init__(
self,
*, # the order of these args can change at any time, so they are keyword-only
data: Dict[str, Any],
model: Dict[str, Any],
logging: Dict[str, Any],
checkpoint: Dict[str, Any],
max_epochs: int,
mode: str = "train",
accelerator: str = "cuda",
seed_value: int = 123,
val_epoch_freq: int = 1,
distributed: Dict[str, bool] = None,
cuda: Dict[str, bool] = None,
env_variables: Optional[Dict[str, Any]] = None,
optim: Optional[Dict[str, Any]] = None,
optim_overrides: Optional[List[Dict[str, Any]]] = None,
meters: Optional[Dict[str, Any]] = None,
loss: Optional[Dict[str, Any]] = None,
):
self._setup_env_variables(env_variables)
self._setup_timers()
self.data_conf = data
self.model_conf = model
self.logging_conf = LoggingConf(**logging)
self.checkpoint_conf = CheckpointConf(**checkpoint).infer_missing()
self.max_epochs = max_epochs
self.mode = mode
self.val_epoch_freq = val_epoch_freq
self.optim_conf = OptimConf(**optim) if optim is not None else None
self.meters_conf = meters
self.loss_conf = loss
distributed = DistributedConf(**distributed or {})
cuda = CudaConf(**cuda or {})
self.where = 0.0
self._infer_distributed_backend_if_none(distributed, accelerator)
self._setup_device(accelerator)
self._setup_torch_dist_and_backend(cuda, distributed)
makedir(self.logging_conf.log_dir)
setup_logging(
__name__,
output_dir=self.logging_conf.log_dir,
rank=self.rank,
log_level_primary=self.logging_conf.log_level_primary,
log_level_secondary=self.logging_conf.log_level_secondary,
)
set_seeds(seed_value, self.max_epochs, self.distributed_rank)
log_env_variables()
assert (
is_dist_avail_and_initialized()
), "Torch distributed needs to be initialized before calling the trainer."
self._setup_components() # Except Optimizer everything is setup here.
self._move_to_device()
self._construct_optimizers()
self._setup_dataloaders()
self.time_elapsed_meter = DurationMeter("Time Elapsed", self.device, ":.2f")
if self.checkpoint_conf.resume_from is not None:
assert os.path.exists(
self.checkpoint_conf.resume_from
), f"The 'resume_from' checkpoint {self.checkpoint_conf.resume_from} does not exist!"
dst = os.path.join(self.checkpoint_conf.save_dir, "checkpoint.pt")
if self.distributed_rank == 0 and not os.path.exists(dst):
# Copy the "resume_from" checkpoint to the checkpoint folder
# if there is not a checkpoint to resume from already there
makedir(self.checkpoint_conf.save_dir)
g_pathmgr.copy(self.checkpoint_conf.resume_from, dst)
barrier()
self.load_checkpoint()
self._setup_ddp_distributed_training(distributed, accelerator)
barrier()
def _setup_timers(self):
"""
Initializes counters for elapsed time and eta.
"""
self.start_time = time.time()
self.ckpt_time_elapsed = 0
self.est_epoch_time = dict.fromkeys([Phase.TRAIN, Phase.VAL], 0)
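    # Meters are configured as {phase: {dataset_key: {meter_name: meter}}};
    # flatten them into a single "{phase}_{key}/{name}" -> meter mapping,
    # optionally restricted to the given phases.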
def _get_meters(self, phase_filters=None):
if self.meters is None:
return {}
meters = {}
for phase, phase_meters in self.meters.items():
if phase_filters is not None and phase not in phase_filters:
continue
for key, key_meters in phase_meters.items():
if key_meters is None:
continue
for name, meter in key_meters.items():
meters[f"{phase}_{key}/{name}"] = meter
return meters
def _infer_distributed_backend_if_none(self, distributed_conf, accelerator):
if distributed_conf.backend is None:
distributed_conf.backend = "nccl" if accelerator == "cuda" else "gloo"
def _setup_env_variables(self, env_variables_conf) -> None:
if env_variables_conf is not None:
for variable_name, value in env_variables_conf.items():
os.environ[variable_name] = value
def _setup_torch_dist_and_backend(self, cuda_conf, distributed_conf) -> None:
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = cuda_conf.cudnn_deterministic
torch.backends.cudnn.benchmark = cuda_conf.cudnn_benchmark
torch.backends.cuda.matmul.allow_tf32 = (
cuda_conf.matmul_allow_tf32
if cuda_conf.matmul_allow_tf32 is not None
else cuda_conf.allow_tf32
)
torch.backends.cudnn.allow_tf32 = (
cuda_conf.cudnn_allow_tf32
if cuda_conf.cudnn_allow_tf32 is not None
else cuda_conf.allow_tf32
)
self.rank = setup_distributed_backend(
distributed_conf.backend, distributed_conf.timeout_mins
)
def _setup_device(self, accelerator):
self.local_rank, self.distributed_rank = get_machine_local_and_dist_rank()
if accelerator == "cuda":
self.device = torch.device("cuda", self.local_rank)
torch.cuda.set_device(self.local_rank)
elif accelerator == "cpu":
self.device = torch.device("cpu")
else:
raise ValueError(f"Unsupported accelerator: {accelerator}")
def _setup_ddp_distributed_training(self, distributed_conf, accelerator):
assert isinstance(self.model, torch.nn.Module)
self.model = nn.parallel.DistributedDataParallel(
self.model,
device_ids=[self.local_rank] if accelerator == "cuda" else [],
find_unused_parameters=distributed_conf.find_unused_parameters,
)
if distributed_conf.comms_dtype is not None: # noqa
from torch.distributed.algorithms import ddp_comm_hooks
amp_type = get_amp_type(distributed_conf.comms_dtype)
if amp_type == torch.bfloat16:
hook = ddp_comm_hooks.default_hooks.bf16_compress_hook
logging.info("Enabling bfloat16 grad communication")
else:
hook = ddp_comm_hooks.default_hooks.fp16_compress_hook
logging.info("Enabling fp16 grad communication")
process_group = None
self.model.register_comm_hook(process_group, hook)
def _move_to_device(self):
logging.info(
f"Moving components to device {self.device} and local rank {self.local_rank}."
)
self.model.to(self.device)
logging.info(
f"Done moving components to device {self.device} and local rank {self.local_rank}."
)
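    # When called without explicit names, "checkpoint.pt" is overwritten each
    # time; epochs matching save_freq or listed in save_list additionally get a
    # frozen "checkpoint_{epoch}.pt" copy. Best-meter checkpoints pass their own
    # names via checkpoint_names.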
def save_checkpoint(self, epoch, checkpoint_names=None):
checkpoint_folder = self.checkpoint_conf.save_dir
makedir(checkpoint_folder)
if checkpoint_names is None:
checkpoint_names = ["checkpoint"]
if (
self.checkpoint_conf.save_freq > 0
and (int(epoch) % self.checkpoint_conf.save_freq == 0)
) or int(epoch) in self.checkpoint_conf.save_list:
checkpoint_names.append(f"checkpoint_{int(epoch)}")
checkpoint_paths = []
for ckpt_name in checkpoint_names:
checkpoint_paths.append(os.path.join(checkpoint_folder, f"{ckpt_name}.pt"))
state_dict = unwrap_ddp_if_wrapped(self.model).state_dict()
state_dict = exclude_params_matching_unix_pattern(
patterns=self.checkpoint_conf.skip_saving_parameters, state_dict=state_dict
)
checkpoint = {
"model": state_dict,
"optimizer": self.optim.optimizer.state_dict(),
"epoch": epoch,
"loss": self.loss.state_dict(),
"steps": self.steps,
"time_elapsed": self.time_elapsed_meter.val,
"best_meter_values": self.best_meter_values,
}
if self.optim_conf.amp.enabled:
checkpoint["scaler"] = self.scaler.state_dict()
# DDP checkpoints are only saved on rank 0 (all workers are identical)
if self.distributed_rank != 0:
return
for checkpoint_path in checkpoint_paths:
self._save_checkpoint(checkpoint, checkpoint_path)
def _save_checkpoint(self, checkpoint, checkpoint_path):
"""
Save a checkpoint while guarding against the job being killed in the middle
of checkpoint saving (which corrupts the checkpoint file and ruins the
entire training since usually only the last checkpoint is kept per run).
        We first save the new checkpoint to a temp file (with a '.tmp' suffix),
        then move it to overwrite the old checkpoint_path.
"""
checkpoint_path_tmp = f"{checkpoint_path}.tmp"
with g_pathmgr.open(checkpoint_path_tmp, "wb") as f:
torch.save(checkpoint, f)
# after torch.save is completed, replace the old checkpoint with the new one
if g_pathmgr.exists(checkpoint_path):
# remove the old checkpoint_path file first (otherwise g_pathmgr.mv fails)
g_pathmgr.rm(checkpoint_path)
success = g_pathmgr.mv(checkpoint_path_tmp, checkpoint_path)
assert success
def load_checkpoint(self):
ckpt_path = get_resume_checkpoint(self.checkpoint_conf.save_dir)
if ckpt_path is None:
self._init_model_state()
else:
if self.checkpoint_conf.initialize_after_preemption:
self._call_model_initializer()
self._load_resuming_checkpoint(ckpt_path)
def _init_model_state(self):
# Checking that parameters that won't be saved are indeed frozen
# We do this check here before even saving the model to catch errors
        # as early as possible and not at the end of the first epoch
assert_skipped_parameters_are_frozen(
patterns=self.checkpoint_conf.skip_saving_parameters,
model=self.model,
)
# Checking that parameters that won't be saved are initialized from
# within the model definition, unless `initialize_after_preemption`
# is explicitly set to `True`. If not, this is a bug, and after
# preemption, the `skip_saving_parameters` will have random values
allow_init_skip_parameters = self.checkpoint_conf.initialize_after_preemption
with with_check_parameter_frozen(
patterns=self.checkpoint_conf.skip_saving_parameters,
model=self.model,
disabled=allow_init_skip_parameters,
):
self._call_model_initializer()
def _call_model_initializer(self):
model_weight_initializer = instantiate(
self.checkpoint_conf.model_weight_initializer
)
if model_weight_initializer is not None:
logging.info(
f"Loading pretrained checkpoint from {self.checkpoint_conf.model_weight_initializer}"
)
self.model = model_weight_initializer(model=self.model)
def _load_resuming_checkpoint(self, ckpt_path: str):
logging.info(f"Resuming training from {ckpt_path}")
with g_pathmgr.open(ckpt_path, "rb") as f:
checkpoint = torch.load(f, map_location="cpu")
load_state_dict_into_model(
model=self.model,
state_dict=checkpoint["model"],
ignore_missing_keys=self.checkpoint_conf.skip_saving_parameters,
)
self.optim.optimizer.load_state_dict(checkpoint["optimizer"])
self.loss.load_state_dict(checkpoint["loss"], strict=True)
self.epoch = checkpoint["epoch"]
self.steps = checkpoint["steps"]
self.ckpt_time_elapsed = checkpoint.get("time_elapsed")
if self.optim_conf.amp.enabled and "scaler" in checkpoint:
self.scaler.load_state_dict(checkpoint["scaler"])
self.best_meter_values = checkpoint.get("best_meter_values", {})
if "train_dataset" in checkpoint and self.train_dataset is not None:
self.train_dataset.load_checkpoint_state(checkpoint["train_dataset"])
def is_intermediate_val_epoch(self, epoch):
return epoch % self.val_epoch_freq == 0 and epoch < self.max_epochs - 1
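    # Forward pass + loss for a single batch (shared by train and val): logs the
    # per-component step losses, updates the phase meters, and returns
    # ({loss_name: core_loss}, batch_size, {component_name: value}).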
def _step(
self,
batch: BatchedVideoDatapoint,
model: nn.Module,
phase: str,
):
outputs = model(batch)
targets = batch.masks
batch_size = len(batch.img_batch)
key = batch.dict_key # key for dataset
loss = self.loss[key](outputs, targets)
loss_str = f"Losses/{phase}_{key}_loss"
loss_log_str = os.path.join("Step_Losses", loss_str)
# loss contains multiple sub-components we wish to log
step_losses = {}
if isinstance(loss, dict):
step_losses.update(
{f"Losses/{phase}_{key}_{k}": v for k, v in loss.items()}
)
loss = self._log_loss_detailed_and_return_core_loss(
loss, loss_log_str, self.steps[phase]
)
if self.steps[phase] % self.logging_conf.log_scalar_frequency == 0:
self.logger.log(
loss_log_str,
loss,
self.steps[phase],
)
self.steps[phase] += 1
ret_tuple = {loss_str: loss}, batch_size, step_losses
if phase in self.meters and key in self.meters[phase]:
meters_dict = self.meters[phase][key]
if meters_dict is not None:
for _, meter in meters_dict.items():
meter.update(
find_stages=outputs,
find_metadatas=batch.metadata,
)
return ret_tuple
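    # Entry point: dispatches on self.mode. When resuming, the val epoch for the
    # previously completed train epoch is re-run first, since checkpoints are
    # saved before validation.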
def run(self):
assert self.mode in ["train", "train_only", "val"]
if self.mode == "train":
if self.epoch > 0:
logging.info(f"Resuming training from epoch: {self.epoch}")
# resuming from a checkpoint
if self.is_intermediate_val_epoch(self.epoch - 1):
logging.info("Running previous val epoch")
self.epoch -= 1
self.run_val()
self.epoch += 1
self.run_train()
self.run_val()
elif self.mode == "val":
self.run_val()
elif self.mode == "train_only":
self.run_train()
def _setup_dataloaders(self):
self.train_dataset = None
self.val_dataset = None
if self.mode in ["train", "val"]:
self.val_dataset = instantiate(self.data_conf.get(Phase.VAL, None))
if self.mode in ["train", "train_only"]:
self.train_dataset = instantiate(self.data_conf.train)
def run_train(self):
while self.epoch < self.max_epochs:
dataloader = self.train_dataset.get_loader(epoch=int(self.epoch))
barrier()
outs = self.train_epoch(dataloader)
self.logger.log_dict(outs, self.epoch) # Logged only on rank 0
# log train to text file.
if self.distributed_rank == 0:
with g_pathmgr.open(
os.path.join(self.logging_conf.log_dir, "train_stats.json"),
"a",
) as f:
f.write(json.dumps(outs) + "\n")
# Save checkpoint before validating
self.save_checkpoint(self.epoch + 1)
del dataloader
gc.collect()
            # Run val; the last epoch is skipped here since val runs after the
            # loop anyway
if self.is_intermediate_val_epoch(self.epoch):
self.run_val()
if self.distributed_rank == 0:
self.best_meter_values.update(self._get_trainer_state("train"))
with g_pathmgr.open(
os.path.join(self.logging_conf.log_dir, "best_stats.json"),
"a",
) as f:
f.write(json.dumps(self.best_meter_values) + "\n")
self.epoch += 1
        # epoch was incremented in the loop but the final val step runs outside the loop
self.epoch -= 1
def run_val(self):
if not self.val_dataset:
return
dataloader = self.val_dataset.get_loader(epoch=int(self.epoch))
outs = self.val_epoch(dataloader, phase=Phase.VAL)
del dataloader
gc.collect()
self.logger.log_dict(outs, self.epoch) # Logged only on rank 0
if self.distributed_rank == 0:
with g_pathmgr.open(
os.path.join(self.logging_conf.log_dir, "val_stats.json"),
"a",
) as f:
f.write(json.dumps(outs) + "\n")
def val_epoch(self, val_loader, phase):
batch_time = AverageMeter("Batch Time", self.device, ":.2f")
data_time = AverageMeter("Data Time", self.device, ":.2f")
mem = MemMeter("Mem (GB)", self.device, ":.2f")
iters_per_epoch = len(val_loader)
curr_phases = [phase]
curr_models = [self.model]
loss_names = []
for p in curr_phases:
for key in self.loss.keys():
loss_names.append(f"Losses/{p}_{key}_loss")
loss_mts = OrderedDict(
[(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
)
extra_loss_mts = {}
for model in curr_models:
model.eval()
if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"):
unwrap_ddp_if_wrapped(model).on_validation_epoch_start()
progress = ProgressMeter(
iters_per_epoch,
[batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()],
self._get_meters(curr_phases),
prefix="Val Epoch: [{}]".format(self.epoch),
)
end = time.time()
for data_iter, batch in enumerate(val_loader):
# measure data loading time
data_time.update(time.time() - end)
batch = batch.to(self.device, non_blocking=True)
# compute output
with torch.no_grad():
with torch.cuda.amp.autocast(
enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
dtype=(
get_amp_type(self.optim_conf.amp.amp_dtype)
if self.optim_conf
else None
),
):
for phase, model in zip(curr_phases, curr_models):
loss_dict, batch_size, extra_losses = self._step(
batch,
model,
phase,
)
assert len(loss_dict) == 1
loss_key, loss = loss_dict.popitem()
loss_mts[loss_key].update(loss.item(), batch_size)
for k, v in extra_losses.items():
if k not in extra_loss_mts:
extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e")
extra_loss_mts[k].update(v.item(), batch_size)
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
self.time_elapsed_meter.update(
time.time() - self.start_time + self.ckpt_time_elapsed
)
if torch.cuda.is_available():
mem.update(reset_peak_usage=True)
if data_iter % self.logging_conf.log_freq == 0:
progress.display(data_iter)
if data_iter % self.logging_conf.log_scalar_frequency == 0:
# Log progress meters.
for progress_meter in progress.meters:
self.logger.log(
os.path.join("Step_Stats", phase, progress_meter.name),
progress_meter.val,
self.steps[Phase.VAL],
)
if data_iter % 10 == 0:
dist.barrier()
self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch
self._log_timers(phase)
for model in curr_models:
if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"):
unwrap_ddp_if_wrapped(model).on_validation_epoch_end()
out_dict = self._log_meters_and_save_best_ckpts(curr_phases)
for k, v in loss_mts.items():
out_dict[k] = v.avg
for k, v in extra_loss_mts.items():
out_dict[k] = v.avg
for phase in curr_phases:
out_dict.update(self._get_trainer_state(phase))
self._reset_meters(curr_phases)
logging.info(f"Meters: {out_dict}")
return out_dict
def _get_trainer_state(self, phase):
return {
"Trainer/where": self.where,
"Trainer/epoch": self.epoch,
f"Trainer/steps_{phase}": self.steps[phase],
}
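    # One full pass over the train loader: forward/backward via _run_step,
    # scheduler stepping based on overall progress ("where" in [0, 1]), optional
    # gradient clipping/logging, then an AMP-scaled optimizer step.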
def train_epoch(self, train_loader):
# Init stat meters
batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f")
data_time_meter = AverageMeter("Data Time", self.device, ":.2f")
mem_meter = MemMeter("Mem (GB)", self.device, ":.2f")
data_times = []
phase = Phase.TRAIN
iters_per_epoch = len(train_loader)
loss_names = []
for batch_key in self.loss.keys():
loss_names.append(f"Losses/{phase}_{batch_key}_loss")
loss_mts = OrderedDict(
[(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
)
extra_loss_mts = {}
progress = ProgressMeter(
iters_per_epoch,
[
batch_time_meter,
data_time_meter,
mem_meter,
self.time_elapsed_meter,
*loss_mts.values(),
],
self._get_meters([phase]),
prefix="Train Epoch: [{}]".format(self.epoch),
)
# Model training loop
self.model.train()
end = time.time()
for data_iter, batch in enumerate(train_loader):
# measure data loading time
data_time_meter.update(time.time() - end)
data_times.append(data_time_meter.val)
batch = batch.to(
self.device, non_blocking=True
) # move tensors in a tensorclass
try:
self._run_step(batch, phase, loss_mts, extra_loss_mts)
# compute gradient and do optim step
exact_epoch = self.epoch + float(data_iter) / iters_per_epoch
self.where = float(exact_epoch) / self.max_epochs
assert self.where <= 1 + self.EPSILON
if self.where < 1.0:
self.optim.step_schedulers(
self.where, step=int(exact_epoch * iters_per_epoch)
)
else:
logging.warning(
f"Skipping scheduler update since the training is at the end, i.e, {self.where} of [0,1]."
)
# Log schedulers
if data_iter % self.logging_conf.log_scalar_frequency == 0:
for j, param_group in enumerate(self.optim.optimizer.param_groups):
for option in self.optim.schedulers[j]:
optim_prefix = (
"" + f"{j}_"
if len(self.optim.optimizer.param_groups) > 1
else ""
)
self.logger.log(
os.path.join("Optim", f"{optim_prefix}", option),
param_group[option],
self.steps[phase],
)
# Clipping gradients and detecting diverging gradients
if self.gradient_clipper is not None:
self.scaler.unscale_(self.optim.optimizer)
self.gradient_clipper(model=self.model)
if self.gradient_logger is not None:
self.gradient_logger(
self.model, rank=self.distributed_rank, where=self.where
)
# Optimizer step: the scaler will make sure gradients are not
# applied if the gradients are infinite
self.scaler.step(self.optim.optimizer)
self.scaler.update()
# measure elapsed time
batch_time_meter.update(time.time() - end)
end = time.time()
self.time_elapsed_meter.update(
time.time() - self.start_time + self.ckpt_time_elapsed
)
mem_meter.update(reset_peak_usage=True)
if data_iter % self.logging_conf.log_freq == 0:
progress.display(data_iter)
if data_iter % self.logging_conf.log_scalar_frequency == 0:
# Log progress meters.
for progress_meter in progress.meters:
self.logger.log(
os.path.join("Step_Stats", phase, progress_meter.name),
progress_meter.val,
self.steps[phase],
)
# Catching NaN/Inf errors in the loss
except FloatingPointError as e:
raise e
self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch
self._log_timers(Phase.TRAIN)
self._log_sync_data_times(Phase.TRAIN, data_times)
out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN])
for k, v in loss_mts.items():
out_dict[k] = v.avg
for k, v in extra_loss_mts.items():
out_dict[k] = v.avg
out_dict.update(self._get_trainer_state(phase))
logging.info(f"Losses and meters: {out_dict}")
self._reset_meters([phase])
return out_dict
def _log_sync_data_times(self, phase, data_times):
data_times = all_reduce_max(torch.tensor(data_times)).tolist()
steps = range(self.steps[phase] - len(data_times), self.steps[phase])
for step, data_time in zip(steps, data_times):
if step % self.logging_conf.log_scalar_frequency == 0:
self.logger.log(
os.path.join("Step_Stats", phase, "Data Time Synced"),
data_time,
step,
)
def _run_step(
self,
batch: BatchedVideoDatapoint,
phase: str,
loss_mts: Dict[str, AverageMeter],
extra_loss_mts: Dict[str, AverageMeter],
raise_on_error: bool = True,
):
"""
Run the forward / backward
"""
# it's important to set grads to None, especially with Adam since 0
# grads will also update a model even if the step doesn't produce
# gradients
self.optim.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast(
enabled=self.optim_conf.amp.enabled,
dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
):
loss_dict, batch_size, extra_losses = self._step(
batch,
self.model,
phase,
)
assert len(loss_dict) == 1
loss_key, loss = loss_dict.popitem()
if not math.isfinite(loss.item()):
error_msg = f"Loss is {loss.item()}, attempting to stop training"
logging.error(error_msg)
if raise_on_error:
raise FloatingPointError(error_msg)
else:
return
self.scaler.scale(loss).backward()
loss_mts[loss_key].update(loss.item(), batch_size)
for extra_loss_key, extra_loss in extra_losses.items():
if extra_loss_key not in extra_loss_mts:
extra_loss_mts[extra_loss_key] = AverageMeter(
extra_loss_key, self.device, ":.2e"
)
extra_loss_mts[extra_loss_key].update(extra_loss.item(), batch_size)
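    # Sync meters across ranks, track the best value seen for each meter, and
    # save extra checkpoints for meters listed in checkpoint_conf.save_best_meters
    # whenever they improve.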
def _log_meters_and_save_best_ckpts(self, phases: List[str]):
logging.info("Synchronizing meters")
out_dict = {}
checkpoint_save_keys = []
for key, meter in self._get_meters(phases).items():
meter_output = meter.compute_synced()
is_better_check = getattr(meter, "is_better", None)
for meter_subkey, meter_value in meter_output.items():
out_dict[os.path.join("Meters_train", key, meter_subkey)] = meter_value
if is_better_check is None:
continue
tracked_meter_key = os.path.join(key, meter_subkey)
if tracked_meter_key not in self.best_meter_values or is_better_check(
meter_value,
self.best_meter_values[tracked_meter_key],
):
self.best_meter_values[tracked_meter_key] = meter_value
if (
self.checkpoint_conf.save_best_meters is not None
and key in self.checkpoint_conf.save_best_meters
):
checkpoint_save_keys.append(tracked_meter_key.replace("/", "_"))
if len(checkpoint_save_keys) > 0:
self.save_checkpoint(self.epoch + 1, checkpoint_save_keys)
return out_dict
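    # ETA estimate: remaining train epochs * estimated train-epoch time plus
    # remaining val epochs * estimated val-epoch time, using the per-phase
    # averages recorded in est_epoch_time.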
def _log_timers(self, phase):
time_remaining = 0
epochs_remaining = self.max_epochs - self.epoch - 1
val_epochs_remaining = sum(
n % self.val_epoch_freq == 0 for n in range(self.epoch, self.max_epochs)
)
# Adding the guaranteed val run at the end if val_epoch_freq doesn't coincide with
# the end epoch.
if (self.max_epochs - 1) % self.val_epoch_freq != 0:
val_epochs_remaining += 1
# Remove the current val run from estimate
if phase == Phase.VAL:
val_epochs_remaining -= 1
time_remaining += (
epochs_remaining * self.est_epoch_time[Phase.TRAIN]
+ val_epochs_remaining * self.est_epoch_time[Phase.VAL]
)
self.logger.log(
os.path.join("Step_Stats", phase, self.time_elapsed_meter.name),
self.time_elapsed_meter.val,
self.steps[phase],
)
logging.info(f"Estimated time remaining: {human_readable_time(time_remaining)}")
    def _reset_meters(self, phases: List[str]) -> None:
for meter in self._get_meters(phases).values():
meter.reset()
def _check_val_key_match(self, val_keys, phase):
if val_keys is not None:
# Check if there are any duplicates
assert len(val_keys) == len(
set(val_keys)
), f"Duplicate keys in val datasets, keys: {val_keys}"
# Check that the keys match the meter keys
if self.meters_conf is not None and phase in self.meters_conf:
assert set(val_keys) == set(self.meters_conf[phase].keys()), (
f"Keys in val datasets do not match the keys in meters."
f"\nMissing in meters: {set(val_keys) - set(self.meters_conf[phase].keys())}"
f"\nMissing in val datasets: {set(self.meters_conf[phase].keys()) - set(val_keys)}"
)
if self.loss_conf is not None:
loss_keys = set(self.loss_conf.keys()) - set(["all"])
assert all([k in loss_keys for k in val_keys]), (
f"Keys in val datasets do not match the keys in losses."
f"\nMissing in losses: {set(val_keys) - loss_keys}"
f"\nMissing in val datasets: {loss_keys - set(val_keys)}"
)
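    # Build everything except the optimizer: logger, model, per-dataset losses
    # (a ModuleDict keyed by dataset), meters, AMP grad scaler, and optional
    # gradient clipper/logger.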
def _setup_components(self):
# Get the keys for all the val datasets, if any
val_phase = Phase.VAL
val_keys = None
if self.data_conf.get(val_phase, None) is not None:
val_keys = collect_dict_keys(self.data_conf[val_phase])
# Additional checks on the sanity of the config for val datasets
self._check_val_key_match(val_keys, phase=val_phase)
logging.info("Setting up components: Model, loss, optim, meters etc.")
self.epoch = 0
self.steps = {Phase.TRAIN: 0, Phase.VAL: 0}
self.logger = Logger(self.logging_conf)
self.model = instantiate(self.model_conf, _convert_="all")
print_model_summary(self.model)
self.loss = None
if self.loss_conf:
self.loss = {
key: el # wrap_base_loss(el)
for (key, el) in instantiate(self.loss_conf, _convert_="all").items()
}
self.loss = nn.ModuleDict(self.loss)
self.meters = {}
self.best_meter_values = {}
if self.meters_conf:
self.meters = instantiate(self.meters_conf, _convert_="all")
self.scaler = torch.amp.GradScaler(
self.device,
enabled=self.optim_conf.amp.enabled if self.optim_conf else False,
)
self.gradient_clipper = (
instantiate(self.optim_conf.gradient_clip) if self.optim_conf else None
)
self.gradient_logger = (
instantiate(self.optim_conf.gradient_logger) if self.optim_conf else None
)
logging.info("Finished setting up components: Model, loss, optim, meters etc.")
def _construct_optimizers(self):
self.optim = construct_optimizer(
self.model,
self.optim_conf.optimizer,
self.optim_conf.options,
self.optim_conf.param_group_modifiers,
)
def _log_loss_detailed_and_return_core_loss(self, loss, loss_str, step):
core_loss = loss.pop(CORE_LOSS_KEY)
if step % self.logging_conf.log_scalar_frequency == 0:
for k in loss:
log_str = os.path.join(loss_str, k)
self.logger.log(log_str, loss[k], step)
return core_loss
def print_model_summary(model: torch.nn.Module, log_dir: str = ""):
"""
Prints the model and the number of parameters in the model.
# Multiple packages provide this info in a nice table format
# However, they need us to provide an `input` (as they also write down the output sizes)
# Our models are complex, and a single input is restrictive.
# https://github.com/sksq96/pytorch-summary
# https://github.com/nmhkahn/torchsummaryX
"""
if get_rank() != 0:
return
param_kwargs = {}
trainable_parameters = sum(
p.numel() for p in model.parameters(**param_kwargs) if p.requires_grad
)
total_parameters = sum(p.numel() for p in model.parameters(**param_kwargs))
non_trainable_parameters = total_parameters - trainable_parameters
logging.info("==" * 10)
logging.info(f"Summary for model {type(model)}")
logging.info(f"Model is {model}")
logging.info(f"\tTotal parameters {get_human_readable_count(total_parameters)}")
logging.info(
f"\tTrainable parameters {get_human_readable_count(trainable_parameters)}"
)
logging.info(
f"\tNon-Trainable parameters {get_human_readable_count(non_trainable_parameters)}"
)
logging.info("==" * 10)
if log_dir:
output_fpath = os.path.join(log_dir, "model.txt")
with g_pathmgr.open(output_fpath, "w") as f:
print(model, file=f)
PARAMETER_NUM_UNITS = [" ", "K", "M", "B", "T"]
def get_human_readable_count(number: int) -> str:
"""
Abbreviates an integer number with K, M, B, T for thousands, millions,
billions and trillions, respectively.
Examples:
>>> get_human_readable_count(123)
'123 '
>>> get_human_readable_count(1234) # (one thousand)
'1.2 K'
>>> get_human_readable_count(2e6) # (two million)
'2.0 M'
>>> get_human_readable_count(3e9) # (three billion)
'3.0 B'
>>> get_human_readable_count(4e14) # (four hundred trillion)
'400 T'
>>> get_human_readable_count(5e15) # (more than trillion)
'5,000 T'
Args:
number: a positive integer number
Return:
A string formatted according to the pattern described above.
"""
assert number >= 0
labels = PARAMETER_NUM_UNITS
num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
num_groups = int(np.ceil(num_digits / 3))
num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
shift = -3 * (num_groups - 1)
number = number * (10**shift)
index = num_groups - 1
if index < 1 or number >= 100:
return f"{int(number):,d} {labels[index]}"
else:
return f"{number:,.1f} {labels[index]}"