import os
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

import yaml
import torch
import wandb
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.distributed import init_process_group, destroy_process_group, barrier
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)

from .utils import Base, EXPECTED, EXPECTED_TRAIN
from .utils import create_folder_if_necessary, safe_save, load_or_fail

# pylint: disable=unused-argument
class WarpCore(ABC):
    class Config(Base):
        experiment_id: str = EXPECTED_TRAIN
        checkpoint_path: str = EXPECTED_TRAIN
        output_path: str = EXPECTED_TRAIN
        checkpoint_extension: str = "safetensors"
        dist_file_subfolder: str = ""
        allow_tf32: bool = True

        wandb_project: str = None
        wandb_entity: str = None

    # not frozen, so fields stay mutable (and kwargs from a saved info.json can rebuild it)
    @dataclass()
    class Info:  # not inheriting from Base, because we don't want to enforce the default fields
        wandb_run_id: str = None
        total_steps: int = 0
        iter: int = 0

    class Data(Base):
        dataset: Dataset = EXPECTED
        dataloader: DataLoader = EXPECTED
        iterator: Any = EXPECTED

    class Models(Base):
        pass

    class Optimizers(Base):
        pass

    class Schedulers(Base):
        pass

    class Extras(Base):
        pass
    # ---------------------------------------
    info: Info
    config: Config

    # FSDP defaults
    fsdp_defaults = {
        "sharding_strategy": ShardingStrategy.SHARD_GRAD_OP,
        "cpu_offload": None,
        "mixed_precision": MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        "limit_all_gathers": True,
    }
    fsdp_fullstate_save_policy = FullStateDictConfig(
        offload_to_cpu=True, rank0_only=True
    )
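    # Descriptive note on the defaults above: SHARD_GRAD_OP shards gradients and
    # optimizer state but keeps the gathered parameters between forward and
    # backward (less communication than full sharding, at the cost of memory);
    # mixed precision casts params, buffers and gradient reductions to bfloat16.
    # The full-state save policy gathers an unsharded state dict, offloaded to
    # CPU and materialized only on rank 0, so checkpoints are written once by
    # the main node.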
    # ------------

    # OVERRIDEABLE METHODS
    # (see the usage sketch at the end of this file for a minimal subclass)

    # [optionally] set up extra stuff; called BEFORE the models & optimizers are set up
    def setup_extras_pre(self) -> Extras:
        return self.Extras()

    # set up dataset & dataloader; return a DTO containing the dataset, dataloader and/or iterator
    def setup_data(self, extras: Extras) -> Data:
        raise NotImplementedError("This method needs to be overridden")

    # return a DTO with all models that are going to be used in the training
    def setup_models(self, extras: Extras) -> Models:
        raise NotImplementedError("This method needs to be overridden")

    # return a DTO with all optimizers that are going to be used in the training
    def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
        raise NotImplementedError("This method needs to be overridden")

    # [optionally] return a DTO with all schedulers that are going to be used in the training
    def setup_schedulers(self, extras: Extras, models: Models, optimizers: Optimizers) -> Schedulers:
        return self.Schedulers()

    # [optionally] set up extra stuff; called AFTER the models & optimizers are set up
    def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers) -> Extras:
        return self.Extras.from_dict(extras.to_dict())

    # perform the training here
    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
        raise NotImplementedError("This method needs to be overridden")
    # ------------

    def setup_info(self, full_path=None) -> Info:
        if full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info.json"
        info_dict = load_or_fail(full_path, wandb_run_id=None) or {}
        info_dto = self.Info(**info_dict)
        if info_dto.total_steps > 0 and self.is_main_node:
            print(">>> RESUMING TRAINING FROM ITER ", info_dto.total_steps)
        return info_dto

    def setup_config(self, config_file_path=None, config_dict=None, training=True) -> Config:
        if config_file_path is not None:
            if config_file_path.endswith(".yml") or config_file_path.endswith(".yaml"):
                with open(config_file_path, "r", encoding="utf-8") as file:
                    loaded_config = yaml.safe_load(file)
            elif config_file_path.endswith(".json"):
                with open(config_file_path, "r", encoding="utf-8") as file:
                    loaded_config = json.load(file)
            else:
                raise ValueError("Config file must be either a .yml|.yaml or .json file")
            return self.Config.from_dict({**loaded_config, 'training': training})
        if config_dict is not None:
            return self.Config.from_dict({**config_dict, 'training': training})
        return self.Config(training=training)
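    # For reference, a YAML file passed to setup_config could look like this
    # (hypothetical values, matching the Config fields above):
    #   experiment_id: my_experiment
    #   checkpoint_path: /data/checkpoints
    #   output_path: /data/outputs
    #   wandb_project: my_project
    #   wandb_entity: my_team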
    def setup_ddp(self, experiment_id, single_gpu=False):
        if not single_gpu:
            local_rank = int(os.environ.get("SLURM_LOCALID"))
            process_id = int(os.environ.get("SLURM_PROCID"))
            world_size = int(os.environ.get("SLURM_NNODES")) * torch.cuda.device_count()

            self.process_id = process_id
            self.is_main_node = process_id == 0
            self.device = torch.device(local_rank)
            self.world_size = world_size

            dist_file_path = f"{os.getcwd()}/{self.config.dist_file_subfolder}dist_file_{experiment_id}"
            # if os.path.exists(dist_file_path) and self.is_main_node:
            #     os.remove(dist_file_path)

            torch.cuda.set_device(local_rank)
            init_process_group(
                backend="nccl",
                rank=process_id,
                world_size=world_size,
                init_method=f"file://{dist_file_path}",
            )
            print(f"[GPU {process_id}] READY")
        else:
            print("Running in a single process, DDP not enabled.")
    def setup_wandb(self):
        if self.is_main_node and self.config.wandb_project is not None:
            self.info.wandb_run_id = self.info.wandb_run_id or wandb.util.generate_id()
            wandb.init(
                project=self.config.wandb_project,
                entity=self.config.wandb_entity,
                name=self.config.experiment_id,
                id=self.info.wandb_run_id,
                resume="allow",
                config=self.config.to_dict(),
            )
            if self.info.total_steps > 0:
                wandb.alert(
                    title=f"Training {self.info.wandb_run_id} resumed",
                    text=f"Training {self.info.wandb_run_id} resumed from step {self.info.total_steps}",
                )
            else:
                wandb.alert(
                    title=f"Training {self.info.wandb_run_id} started",
                    text=f"Training {self.info.wandb_run_id} started",
                )
    # LOAD UTILITIES ----------
    def load_model(self, model, model_id=None, full_path=None, strict=True):
        if model_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
        elif full_path is None and model_id is None:
            raise ValueError(
                "This method expects either 'model_id' or 'full_path' to be defined"
            )
        checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
        if checkpoint is not None:
            model.load_state_dict(checkpoint, strict=strict)
            del checkpoint
        return model
    def load_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
        if optim_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
        elif full_path is None and optim_id is None:
            raise ValueError(
                "This method expects either 'optim_id' or 'full_path' to be defined"
            )
        checkpoint = load_or_fail(full_path, wandb_run_id=self.info.wandb_run_id if self.is_main_node else None)
        if checkpoint is not None:
            try:
                if fsdp_model is not None:
                    # scatter the full (rank-0) optimizer state to the sharded ranks
                    sharded_optimizer_state_dict = FSDP.scatter_full_optim_state_dict(
                        checkpoint
                        if (
                            self.is_main_node
                            or self.fsdp_defaults["sharding_strategy"] == ShardingStrategy.NO_SHARD
                        )
                        else None,
                        fsdp_model,
                    )
                    optim.load_state_dict(sharded_optimizer_state_dict)
                    del checkpoint, sharded_optimizer_state_dict
                else:
                    optim.load_state_dict(checkpoint)
            # pylint: disable=broad-except
            except Exception as e:
                print("!!! Failed loading optimizer, skipping... Exception:", e)
        return optim
    # SAVE UTILITIES ----------
    def save_info(self, info, suffix=""):
        full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/info{suffix}.json"
        create_folder_if_necessary(full_path)
        if self.is_main_node:
            safe_save(vars(info), full_path)

    def save_model(self, model, model_id=None, full_path=None, is_fsdp=False):
        if model_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{model_id}.{self.config.checkpoint_extension}"
        elif full_path is None and model_id is None:
            raise ValueError(
                "This method expects either 'model_id' or 'full_path' to be defined"
            )
        create_folder_if_necessary(full_path)
        if is_fsdp:
            with FSDP.summon_full_params(model):
                pass
            with FSDP.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, self.fsdp_fullstate_save_policy
            ):
                checkpoint = model.state_dict()
                if self.is_main_node:
                    safe_save(checkpoint, full_path)
                del checkpoint
        else:
            if self.is_main_node:
                checkpoint = model.state_dict()
                safe_save(checkpoint, full_path)
                del checkpoint

    def save_optimizer(self, optim, optim_id=None, full_path=None, fsdp_model=None):
        if optim_id is not None and full_path is None:
            full_path = f"{self.config.checkpoint_path}/{self.config.experiment_id}/{optim_id}.pt"
        elif full_path is None and optim_id is None:
            raise ValueError(
                "This method expects either 'optim_id' or 'full_path' to be defined"
            )
        create_folder_if_necessary(full_path)
        if fsdp_model is not None:
            optim_statedict = FSDP.full_optim_state_dict(fsdp_model, optim)
            if self.is_main_node:
                safe_save(optim_statedict, full_path)
            del optim_statedict
        else:
            if self.is_main_node:
                checkpoint = optim.state_dict()
                safe_save(checkpoint, full_path)
                del checkpoint
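    # For reference, the save utilities above lay out checkpoints as:
    #   {checkpoint_path}/{experiment_id}/info{suffix}.json
    #   {checkpoint_path}/{experiment_id}/{model_id}.{checkpoint_extension}
    #   {checkpoint_path}/{experiment_id}/{optim_id}.pt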
    # -----

    def __init__(self, config_file_path=None, config_dict=None, device="cpu", training=True):
        # Temporary setup, will be overridden by setup_ddp if required
        self.device = device
        self.process_id = 0
        self.is_main_node = True
        self.world_size = 1
        # ----

        self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
        self.info: self.Info = self.setup_info()
    def __call__(self, single_gpu=False):
        self.setup_ddp(self.config.experiment_id, single_gpu=single_gpu)  # this will change the device to the CUDA rank
        self.setup_wandb()
        if self.config.allow_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        if self.is_main_node:
            print()
            print("**STARTING JOB WITH CONFIG:**")
            print(yaml.dump(self.config.to_dict(), default_flow_style=False))
            print("------------------------------------")
            print()
            print("**INFO:**")
            print(yaml.dump(vars(self.info), default_flow_style=False))
            print("------------------------------------")
            print()

        # SETUP STUFF
        extras = self.setup_extras_pre()
        assert extras is not None, "setup_extras_pre() must return a DTO"

        data = self.setup_data(extras)
        assert data is not None, "setup_data() must return a DTO"
        if self.is_main_node:
            print("**DATA:**")
            print(yaml.dump({k: type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()

        models = self.setup_models(extras)
        assert models is not None, "setup_models() must return a DTO"
        if self.is_main_node:
            print("**MODELS:**")
            print(yaml.dump({
                k: f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
            }, default_flow_style=False))
            print("------------------------------------")
            print()

        optimizers = self.setup_optimizers(extras, models)
        assert optimizers is not None, "setup_optimizers() must return a DTO"
        if self.is_main_node:
            print("**OPTIMIZERS:**")
            print(yaml.dump({k: type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()

        schedulers = self.setup_schedulers(extras, models, optimizers)
        assert schedulers is not None, "setup_schedulers() must return a DTO"
        if self.is_main_node:
            print("**SCHEDULERS:**")
            print(yaml.dump({k: type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()

        post_extras = self.setup_extras_post(extras, models, optimizers, schedulers)
        assert post_extras is not None, "setup_extras_post() must return a DTO"
        extras = self.Extras.from_dict({**extras.to_dict(), **post_extras.to_dict()})
        if self.is_main_node:
            print("**EXTRAS:**")
            print(yaml.dump({k: f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
            print("------------------------------------")
            print()
        # -------

        # TRAIN
        if self.is_main_node:
            print("**TRAINING STARTING...**")
        self.train(data, extras, models, optimizers, schedulers)

        if single_gpu is False:
            barrier()
            destroy_process_group()
        if self.is_main_node:
            print()
            print("------------------------------------")
            print()
            print("**TRAINING COMPLETE**")
            if self.config.wandb_project is not None:
                wandb.alert(title=f"Training {self.info.wandb_run_id} finished", text=f"Training {self.info.wandb_run_id} finished")
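
# ------------
# Usage sketch (illustrative only): a minimal subclass showing the override
# contract of the OVERRIDEABLE METHODS above. Everything in it is a
# hypothetical placeholder invented for this example (toy tensors, a single
# linear layer, AdamW with an arbitrary learning rate), and it assumes the
# Base DTOs accept their declared fields as keyword arguments, as they are
# used elsewhere in this file.
class ToyTrainer(WarpCore):
    class Models(WarpCore.Models):
        model: nn.Module = EXPECTED

    class Optimizers(WarpCore.Optimizers):
        opt: torch.optim.Optimizer = EXPECTED

    def setup_data(self, extras: WarpCore.Extras) -> WarpCore.Data:
        # toy regression data; a real trainer would build its own pipeline
        dataset = torch.utils.data.TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        dataloader = DataLoader(dataset, batch_size=8)
        return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))

    def setup_models(self, extras: WarpCore.Extras) -> Models:
        return self.Models(model=nn.Linear(4, 1).to(self.device))

    def setup_optimizers(self, extras: WarpCore.Extras, models: Models) -> Optimizers:
        return self.Optimizers(opt=torch.optim.AdamW(models.model.parameters(), lr=1e-3))

    def train(self, data, extras, models, optimizers, schedulers):
        # one pass over the toy data, counting steps in self.info
        for x, y in data.dataloader:
            loss = nn.functional.mse_loss(models.model(x.to(self.device)), y.to(self.device))
            optimizers.opt.zero_grad()
            loss.backward()
            optimizers.opt.step()
            self.info.total_steps += 1

# Hypothetical invocation (paths and ids are placeholders):
#   trainer = ToyTrainer(config_dict={
#       "experiment_id": "toy_run",
#       "checkpoint_path": "/tmp/warpcore",
#       "output_path": "/tmp/warpcore",
#   })
#   trainer(single_gpu=True)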