File size: 15,866 Bytes
5231633 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 |
import os
import yaml
import torch
from torch import nn
import wandb
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from torch.distributed import init_process_group, destroy_process_group, barrier
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
StateDictType
)
from .utils import Base, EXPECTED, EXPECTED_TRAIN
from .utils import create_folder_if_necessary, safe_save, load_or_fail
# pylint: disable=unused-argument
# Abstract base class for a training job. Subclasses provide the data, models,
# optimizers and training loop through the setup_*/train hooks; this class
# supplies config/info loading, distributed and W&B setup, and checkpoint
# load/save helpers.
class WarpCore(ABC):
# Immutable job configuration.
# NOTE(review): fields defaulting to EXPECTED_TRAIN are presumably mandatory
# when training -- confirm against Base / EXPECTED_TRAIN in .utils.
@dataclass(frozen=True)
class Config(Base):
# Sub-folder name under checkpoint_path; also used as the W&B run name and in
# the distributed rendezvous file name.
experiment_id: str = EXPECTED_TRAIN
# Root directory for checkpoints: "<checkpoint_path>/<experiment_id>/...".
checkpoint_path: str = EXPECTED_TRAIN
# Not referenced by the code visible in this file.
output_path: str = EXPECTED_TRAIN
# File extension used for model checkpoints (optimizers always use ".pt").
checkpoint_extension: str = "safetensors"
# Inserted verbatim before "dist_file_<experiment_id>" when building the
# rendezvous file path, so a non-empty value needs its own trailing "/".
dist_file_subfolder: str = ""
# When True, __call__ enables TF32 for CUDA matmul and cuDNN.
allow_tf32: bool = True
# W&B logging is only set up when wandb_project is not None.
wandb_project: str = None
wandb_entity: str = None
# Mutable run state, loaded by setup_info() and written by save_info().
@dataclass() # not frozen, means that fields are mutable
class Info(): # not inheriting from Base, because we don't want to enforce the default fields
# Generated in setup_wandb() when still unset.
wandb_run_id: str = None
# A value > 0 is treated as "resuming a previous run".
total_steps: int = 0
# Not referenced by the code visible in this file.
iter: int = 0
# Container returned by setup_data().
# NOTE(review): EXPECTED presumably marks fields that must always be provided -- confirm in .utils.
@dataclass(frozen=True)
class Data(Base):
dataset: Dataset = EXPECTED
dataloader: DataLoader = EXPECTED
iterator: any = EXPECTED
# Empty containers returned by the setup_models / setup_optimizers /
# setup_schedulers / setup_extras_* hooks; presumably extended with concrete
# fields by subclasses.
@dataclass(frozen=True)
class Models(Base):
pass
@dataclass(frozen=True)
class Optimizers(Base):
pass
@dataclass(frozen=True)
class Schedulers(Base):
pass
@dataclass(frozen=True)
class Extras(Base):
pass
# ---------------------------------------
# Instance attributes assigned in __init__ from setup_info() / setup_config().
info: Info
config: Config
# FSDP stuff
# Default FSDP settings. Within this file only "sharding_strategy" is read
# (by load_optimizer); presumably subclasses unpack the dict when wrapping
# their models in FSDP -- confirm against callers.
fsdp_defaults = {
"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
"cpu_offload": None,
"mixed_precision": MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
"limit_all_gathers": True,
}
# State-dict configuration used by save_model() when it gathers a full FSDP
# state dict: offload to CPU and materialise on rank 0 only.
fsdp_fullstate_save_policy = FullStateDictConfig(
offload_to_cpu=True, rank0_only=True
)
# ------------
# OVERRIDEABLE METHODS
# [optionally] setup extra stuff, will be called BEFORE the models & optimizers are setup
def setup_extras_pre(self) -> Extras:
    """Build the initial extras container.

    Optional hook: __call__ invokes it first, before data, models and
    optimizers are set up. The default returns an empty ``Extras`` instance.
    """
    initial_extras = self.Extras()
    return initial_extras
# set up dataset & dataloader; return a container holding the dataset, dataloader and/or iterator
@abstractmethod
def setup_data(self, extras: Extras) -> Data:
    """Create the dataset, dataloader and/or iterator for the job.

    Mandatory hook: __call__ passes in the extras produced by
    setup_extras_pre() and asserts that the result is not None.

    Raises:
        NotImplementedError: always, unless overridden by a subclass.
    """
    message = "This method needs to be overriden"
    raise NotImplementedError(message)
# return a dict with all models that are going to be used in the training
@abstractmethod
def setup_models(self, extras: Extras) -> Models:
    """Create every model used during training.

    Mandatory hook: __call__ invokes it after setup_data() and asserts that
    the result is not None.

    Raises:
        NotImplementedError: always, unless overridden by a subclass.
    """
    message = "This method needs to be overriden"
    raise NotImplementedError(message)
# return a dict with all optimizers that are going to be used in the training
@abstractmethod
def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
    """Create every optimizer used during training.

    Mandatory hook: __call__ invokes it after setup_models() and asserts that
    the result is not None.

    Raises:
        NotImplementedError: always, unless overridden by a subclass.
    """
    message = "This method needs to be overriden"
    raise NotImplementedError(message)
# [optionally] return a dict with all schedulers that are going to be used in the training
def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
    """Create the learning-rate schedulers.

    Optional hook: __call__ invokes it after setup_optimizers(). The default
    returns an empty ``Schedulers`` instance.
    """
    no_schedulers = self.Schedulers()
    return no_schedulers
# [optionally] setup extra stuff, will be called AFTER the models & optimizers are setup
def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
    """Build extras that depend on the models, optimizers or schedulers.

    Optional hook: __call__ invokes it after setup_schedulers() and merges the
    result over the pre-extras. The default round-trips ``extras`` through
    ``to_dict()`` / ``Extras.from_dict()`` and returns the rebuilt instance.
    """
    extras_fields = extras.to_dict()
    return self.Extras.from_dict(extras_fields)
# perform the training here
@abstractmethod
def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
    """Run the training loop.

    Mandatory hook: __call__ invokes it once all setup hooks have produced
    their containers.

    Raises:
        NotImplementedError: always, unless overridden by a subclass.
    """
    message = "This method needs to be overriden"
    raise NotImplementedError(message)
# ------------
def setup_info(self, full_path=None) -> Info:
    """Load the persisted run state, or start from default values.

    Args:
        full_path: Location of the info JSON file. Defaults to
            ``<checkpoint_path>/<experiment_id>/info.json``.

    Returns:
        An ``Info`` instance built from the loaded fields; a falsy result from
        ``load_or_fail`` yields an ``Info`` with default values.
    """
    if full_path is None:
        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json"

    # NOTE(review): load_or_fail presumably returns None when the file does
    # not exist -- confirm in .utils.
    stored_fields = load_or_fail(full_path, wandb_run_id=None)
    if not stored_fields:
        stored_fields = {}
    restored = self.Info(**stored_fields)

    is_resuming = restored.total_steps > 0
    if is_resuming and self.is_main_node:
        print(">>> RESUMING TRAINING FROM ITER ", restored.total_steps)
    return restored
def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
    """Build the job ``Config`` from a file, a dict, or defaults.

    Args:
        config_file_path: Path to a ``.yml``/``.yaml`` or ``.json`` file. Takes
            precedence over ``config_dict`` when both are given.
        config_dict: Mapping of config fields, used when no file is given.
        training: Forwarded to the config as its ``training`` field,
            overriding any value found in the file or dict.

    Returns:
        The constructed ``Config`` instance.

    Raises:
        ValueError: if ``config_file_path`` has an unsupported extension.
    """
    if config_file_path is None:
        if config_dict is None:
            return self.Config(training=training)
        return self.Config.from_dict({**config_dict, 'training': training})

    # Choose the parser from the extension before touching the file.
    if config_file_path.endswith((".yml", ".yaml")):
        parse = yaml.safe_load
    elif config_file_path.endswith(".json"):
        parse = json.load
    else:
        raise ValueError("Config file must be either a .yml|.yaml or .json file")

    with open(config_file_path, "r", encoding="utf-8") as handle:
        file_fields = parse(handle)
    return self.Config.from_dict({**file_fields, 'training': training})
def setup_ddp(self, experiment_id, single_gpu=False):
    """Initialise the NCCL process group from the SLURM environment.

    In distributed mode this reads ``SLURM_LOCALID``, ``SLURM_PROCID`` and
    ``SLURM_NNODES``, updates ``process_id``, ``is_main_node``, ``device`` and
    ``world_size`` on the instance, selects the CUDA device and joins the
    process group through a file-based rendezvous located under the current
    working directory.

    Args:
        experiment_id: Used to name the rendezvous file.
        single_gpu: When truthy, nothing is initialised and the instance
            keeps the single-process defaults set in ``__init__``.

    Raises:
        RuntimeError: if distributed mode is requested but one of the
            required SLURM environment variables is not set.
    """
    if single_gpu:
        print("Running in single thread, DDP not enabled.")
        return

    # Fail with an actionable message instead of the opaque
    # "int() argument must be ... not 'NoneType'" raised by int(None).
    required_vars = ("SLURM_LOCALID", "SLURM_PROCID", "SLURM_NNODES")
    missing_vars = [name for name in required_vars if os.environ.get(name) is None]
    if missing_vars:
        raise RuntimeError(
            "Distributed setup requires the SLURM environment variable(s) "
            f"{', '.join(missing_vars)}; launch the job through SLURM or call with single_gpu=True."
        )

    local_rank = int(os.environ["SLURM_LOCALID"])
    process_id = int(os.environ["SLURM_PROCID"])
    # World size is the node count times the GPUs visible on this node.
    world_size = int(os.environ["SLURM_NNODES"]) * torch.cuda.device_count()

    self.process_id = process_id
    self.is_main_node = process_id == 0
    self.device = torch.device(local_rank)
    self.world_size = world_size

    # dist_file_subfolder is inserted verbatim, so it must carry its own
    # trailing "/" when non-empty.
    # NOTE: a rendezvous file left over from an earlier run is not removed
    # here (that cleanup was disabled in the original code).
    dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"

    torch.cuda.set_device(local_rank)
    init_process_group(
        backend="nccl",
        rank=process_id,
        world_size=world_size,
        init_method=f"file://{dist_file_path}",
    )
    print(f"[GPU {process_id}] READY")
def setup_wandb(self):
    """Start (or resume) the Weights & Biases run on the main node.

    Does nothing on non-main nodes or when ``config.wandb_project`` is None.
    Otherwise it makes sure ``info.wandb_run_id`` is set, initialises the run
    with ``resume="allow"`` and sends a "started" or "resumed" alert depending
    on whether ``info.total_steps`` is greater than zero.
    """
    if not self.is_main_node or self.config.wandb_project is None:
        return

    # Reuse the stored run id when resuming; otherwise mint a new one.
    self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
    wandb.init(
        project=self.config.wandb_project,
        entity=self.config.wandb_entity,
        name=self.config.experiment_id,
        id=self.info.wandb_run_id,
        resume="allow",
        config=self.config.to_dict(),
    )

    if self.info.total_steps > 0:
        alert_title = f"Training {self.info.wandb_run_id} resumed"
        alert_text = f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}"
    else:
        alert_title = f"Training {self.info.wandb_run_id} started"
        alert_text = f"Training {self.info.wandb_run_id} started"
    wandb.alert(title=alert_title, text=alert_text)
# LOAD UTILITIES ----------
def load_model(self, model, model_id=None, full_path=None, strict=True):
    """Load a checkpoint into ``model`` when one is available.

    Args:
        model: Object exposing ``load_state_dict(state, strict=...)``.
        model_id: Checkpoint name; resolved to
            ``<checkpoint_path>/<experiment_id>/<model_id>.<checkpoint_extension>``
            when ``full_path`` is not given.
        full_path: Explicit checkpoint path; takes precedence over ``model_id``.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        The same ``model`` object, with the checkpoint applied when
        ``load_or_fail`` returned one.

    Raises:
        ValueError: if neither ``model_id`` nor ``full_path`` is provided.
    """
    if full_path is None:
        if model_id is None:
            raise ValueError(
                "This method expects either 'model_id' or 'full_path' to be defined"
            )
        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"

    # Only the main node hands its W&B run id to load_or_fail.
    # NOTE(review): load_or_fail presumably returns None when no checkpoint
    # exists -- confirm in .utils.
    checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
    if checkpoint is not None:
        model.load_state_dict(checkpoint, strict=strict)
    return model
def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
    """Restore optimizer state from a checkpoint on a best-effort basis.

    Args:
        optim: Optimizer exposing ``load_state_dict``.
        optim_id: Checkpoint name; resolved to
            ``<checkpoint_path>/<experiment_id>/<optim_id>.pt`` when
            ``full_path`` is not given.
        full_path: Explicit checkpoint path; takes precedence over ``optim_id``.
        fsdp_model: When given, the loaded full state dict is scattered with
            ``FSDP.scatter_full_optim_state_dict`` before being applied.

    Returns:
        The same ``optim`` object. Any exception raised while applying the
        state is printed and swallowed, leaving the optimizer as it was.

    Raises:
        ValueError: if neither ``optim_id`` nor ``full_path`` is provided.
    """
    if full_path is None:
        if optim_id is None:
            raise ValueError(
                "This method expects either 'optim_id' or 'full_path' to be defined"
            )
        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"

    checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
    if checkpoint is None:
        return optim

    try:
        if fsdp_model is None:
            optim.load_state_dict(checkpoint)
        else:
            # The full state is passed only from the main node, or from every
            # rank when the default strategy is NO_SHARD; other ranks pass None.
            passes_full_state = (
                self.is_main_node
                or self.fsdp_defaults["sharding_strategy"] == ShardingStrategy.NO_SHARD
            )
            full_state = checkpoint if passes_full_state else None
            local_state = FSDP.scatter_full_optim_state_dict(full_state, fsdp_model)  # <---- FSDP
            optim.load_state_dict(local_state)
            del checkpoint, local_state
    # pylint: disable=broad-except
    except Exception as e:
        print("!!! Failed loading optimizer, skipping... Exception:", e)
    return optim
# SAVE UTILITIES ----------
def save_info(self, info, suffix=""):
    """Persist run state as JSON next to the checkpoints.

    The file is written to
    ``<checkpoint_path>/<experiment_id>/info<suffix>.json``. The folder is
    created on every node; the file is written by the main node only.

    Args:
        info: The ``Info`` object to persist. Previously this argument was
            ignored and ``self.info`` was always written; ``None`` still
            falls back to ``self.info``.
        suffix: Optional text inserted between ``info`` and ``.json``.
    """
    full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
    create_folder_if_necessary(full_path)
    if self.is_main_node:
        payload = self.info if info is None else info
        safe_save(vars(payload), full_path)
def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
    """Write a model's state dict to disk from the main node.

    Args:
        model: Module whose ``state_dict()`` is saved.
        model_id: Checkpoint name; resolved to
            ``<checkpoint_path>/<experiment_id>/<model_id>.<checkpoint_extension>``
            when ``full_path`` is not given.
        full_path: Explicit destination; takes precedence over ``model_id``.
        is_fsdp: When truthy, the state dict is collected on every node under
            the FSDP ``FULL_STATE_DICT`` policy before the main node saves it.

    Raises:
        ValueError: if neither ``model_id`` nor ``full_path`` is provided.
    """
    if full_path is None:
        if model_id is None:
            raise ValueError(
                "This method expects either 'model_id' or 'full_path' to be defined"
            )
        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
    create_folder_if_necessary(full_path)

    if not is_fsdp:
        # Plain module: only the main node builds and writes the state dict.
        if self.is_main_node:
            plain_state = model.state_dict()
            safe_save(plain_state, full_path)
            del plain_state
        return

    # NOTE(review): entering and leaving summon_full_params with an empty body
    # looks like a deliberate workaround -- keep it until confirmed otherwise.
    with FSDP.summon_full_params(model):
        pass
    # state_dict() runs on every node; only the main node writes the result.
    with FSDP.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
    ):
        full_state = model.state_dict()
    if self.is_main_node:
        safe_save(full_state, full_path)
    del full_state
def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
    """Write an optimizer's state dict to disk from the main node.

    Args:
        optim: Optimizer whose state is saved.
        optim_id: Checkpoint name; resolved to
            ``<checkpoint_path>/<experiment_id>/<optim_id>.pt`` when
            ``full_path`` is not given.
        full_path: Explicit destination; takes precedence over ``optim_id``.
        fsdp_model: When given, the state is collected on every node with
            ``FSDP.full_optim_state_dict`` before the main node saves it.

    Raises:
        ValueError: if neither ``optim_id`` nor ``full_path`` is provided.
    """
    if full_path is None:
        if optim_id is None:
            raise ValueError(
                "This method expects either 'optim_id' or 'full_path' to be defined"
            )
        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
    create_folder_if_necessary(full_path)

    if fsdp_model is None:
        # Plain optimizer: only the main node builds and writes the state dict.
        if self.is_main_node:
            plain_state = optim.state_dict()
            safe_save(plain_state, full_path)
            del plain_state
        return

    # full_optim_state_dict runs on every node; only the main node writes.
    gathered_state = FSDP.full_optim_state_dict(fsdp_model, optim)
    if self.is_main_node:
        safe_save(gathered_state, full_path)
    del gathered_state
# -----
def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
    """Load the job configuration and any persisted run state.

    Args:
        config_file_path: Optional config file, forwarded to setup_config().
        config_dict: Optional config mapping, forwarded to setup_config().
        device: Device to use until setup_ddp() replaces it.
        training: Forwarded to setup_config().
    """
    # Single-process defaults; setup_ddp() overwrites them when distributed
    # training is enabled. They must exist before setup_info() runs, because
    # it reads is_main_node.
    self.world_size = 1
    self.process_id = 0
    self.is_main_node = True
    self.device = device
    # ----
    self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
    self.info: self.Info = self.setup_info()
def __call__(self, single_gpu=False):
    """Run the whole job: setup, training and teardown.

    Steps, in order: distributed setup, W&B setup, optional TF32 enablement,
    the setup hooks (extras-pre, data, models, optimizers, schedulers,
    extras-post), ``train()``, and finally process-group teardown. The main
    node prints a YAML summary after each setup step.

    Args:
        single_gpu: When truthy, no process group is created or destroyed.
            Teardown uses the same truthiness test as setup_ddp(), so a falsy
            value other than ``False`` (e.g. ``0`` or ``None``) no longer
            leaves the process group alive.

    Raises:
        AssertionError: if any setup hook returns None.
    """

    def report(title, payload):
        # Print one titled YAML section, then the separator and a blank line.
        print(title)
        print(yaml.dump(payload, default_flow_style=False))
        print("------------------------------------")
        print()

    def type_names(dto):
        # Map every field of a DTO to the class name of its value.
        return {k: type(v).__name__ for k, v in dto.to_dict().items()}

    def describe_model(model):
        # Class name plus the trainable-parameter count for nn.Module instances.
        if isinstance(model, nn.Module):
            trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
            detail = f"trainable params {trainable}"
        else:
            detail = "Not a nn.Module"
        return f"{type(model).__name__} - {detail}"

    self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu)  # this will change the device to the CUDA rank
    self.setup_wandb()
    if self.config.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    if self.is_main_node:
        print()
        report("**STARTIG JOB WITH CONFIG:**", self.config.to_dict())
        report("**INFO:**", vars(self.info))

    # SETUP STUFF
    extras = self.setup_extras_pre()
    assert extras is not None, "setup_extras_pre() must return a DTO"

    data = self.setup_data(extras)
    assert data is not None, "setup_data() must return a DTO"
    if self.is_main_node:
        report("**DATA:**", type_names(data))

    models = self.setup_models(extras)
    assert models is not None, "setup_models() must return a DTO"
    if self.is_main_node:
        report("**MODELS:**", {k: describe_model(v) for k, v in models.to_dict().items()})

    optimizers = self.setup_optimizers(extras, models)
    assert optimizers is not None, "setup_optimizers() must return a DTO"
    if self.is_main_node:
        report("**OPTIMIZERS:**", type_names(optimizers))

    schedulers = self.setup_schedulers(extras, models, optimizers)
    assert schedulers is not None, "setup_schedulers() must return a DTO"
    if self.is_main_node:
        report("**SCHEDULERS:**", type_names(schedulers))

    post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
    assert post_extras is not None, "setup_extras_post() must return a DTO"
    # Post-extras win over pre-extras on conflicting keys.
    extras = self.Extras.from_dict({**extras.to_dict(), **post_extras.to_dict()})
    if self.is_main_node:
        report("**EXTRAS:**", {k: f"{v}" for k, v in extras.to_dict().items()})
    # -------

    # TRAIN
    if self.is_main_node:
        print("**TRAINING STARTING...**")
    self.train(data, extras, models, optimizers, schedulers)

    # Mirror setup_ddp(): a process group exists whenever single_gpu is falsy.
    if not single_gpu:
        barrier()
        destroy_process_group()

    if self.is_main_node:
        print()
        print("------------------------------------")
        print()
        print("**TRAINING COMPLETE**")
        if self.config.wandb_project is not None:
            wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
|