# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from typing import Any

import torch
import dinov2.distributed as distributed
from functools import partial
from fvcore.common.checkpoint import Checkpointer
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp._runtime_utils import _reshard


def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()):
    # Map the config's string values to the corresponding FSDP enums and torch dtypes.
    sharding_strategy_dict = {
        "NO_SHARD": ShardingStrategy.NO_SHARD,
        "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
        "FULL_SHARD": ShardingStrategy.FULL_SHARD,
    }

    dtype_dict = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    mixed_precision_config = MixedPrecision(
        param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype],
        reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype],
        buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype],
    )

    sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy]

    local_rank = distributed.get_local_rank()

    # Return a partially-applied FSDP constructor; callers wrap their module with one call.
    fsdp_wrapper = partial(
        FSDP,
        sharding_strategy=sharding_strategy_config,
        mixed_precision=mixed_precision_config,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap),
    )
    return fsdp_wrapper
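

# Usage sketch (hedged): `model_cfg` is assumed to carry `sharding_strategy` and
# `mixed_precision.{param_dtype,reduce_dtype,buffer_dtype}` fields as in the dinov2
# training configs, and `BlockClass` is a placeholder for whichever nn.Module type
# the ModuleWrapPolicy should auto-wrap (e.g. a transformer block):
#
#   fsdp_wrapper = get_fsdp_wrapper(model_cfg, modules_to_wrap={BlockClass})
#   model = fsdp_wrapper(model)  # parameters are sharded across ranks on wrapping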


def is_fsdp(x):
    return isinstance(x, FSDP)


def is_sharded_fsdp(x):
    return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD


def free_if_fsdp(x):
    if is_sharded_fsdp(x):
        handles = x._handles
        true_list = [True for h in handles]
        # _reshard is a private FSDP utility: free the unsharded parameters held by
        # each handle so only the local shards remain in memory.
        _reshard(x, handles, true_list)


def get_fsdp_modules(x):
    return FSDP.fsdp_modules(x)


def reshard_fsdp_model(x):
    for m in get_fsdp_modules(x):
        free_if_fsdp(m)
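

# Usage sketch (hedged): with SHARD_GRAD_OP-style strategies, a no-grad forward can
# leave the gathered (unsharded) parameters in memory; calling reshard_fsdp_model
# afterwards frees them. `teacher_backbone` and `images` are placeholders:
#
#   with torch.no_grad():
#       teacher_output = teacher_backbone(images)
#   reshard_fsdp_model(teacher_backbone)  # release full parameters, keep local shards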


def rankstr():
    return f"rank_{distributed.get_global_rank()}"


class FSDPCheckpointer(Checkpointer):
    def save(self, name: str, **kwargs: Any) -> None:
        """
        Dump model and checkpointables to a file.

        Args:
            name (str): name of the file.
            kwargs (dict): extra arbitrary data to save.
        """
        if not self.save_dir or not self.save_to_disk:
            return

        data = {}
        # Save this rank's local (sharded) state dict; each rank writes its own file.
        with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
            data["model"] = self.model.state_dict()

        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()
        data.update(kwargs)

        basename = f"{name}.{rankstr()}.pth"
        save_file = os.path.join(self.save_dir, basename)
        assert os.path.basename(save_file) == basename, basename
        self.logger.info("Saving checkpoint to {}".format(save_file))
        with self.path_manager.open(save_file, "wb") as f:
            torch.save(data, f)
        self.tag_last_checkpoint(basename)

    def load(self, *args, **kwargs):
        # Load under LOCAL_STATE_DICT so each rank restores its own shard.
        with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
            return super().load(*args, **kwargs)

    def has_checkpoint(self) -> bool:
        """
        Returns:
            bool: whether a checkpoint exists in the target directory.
        """
        save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
        return self.path_manager.exists(save_file)

    def get_checkpoint_file(self) -> str:
        """
        Returns:
            str: The latest checkpoint file in target directory.
        """
        save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
        try:
            with self.path_manager.open(save_file, "r") as f:
                last_saved = f.read().strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            return ""
        # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got
        #  `Union[bytes, str]`.
        return os.path.join(self.save_dir, last_saved)

    def tag_last_checkpoint(self, last_filename_basename: str) -> None:
        """
        Tag the last checkpoint.

        Args:
            last_filename_basename (str): the basename of the last filename.
        """
        if distributed.is_enabled():
            torch.distributed.barrier()
        save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
        with self.path_manager.open(save_file, "w") as f:
            f.write(last_filename_basename)  # pyre-ignore


# Re-export ShardedGradScaler so training code can import it from this module.
ShardedGradScaler = ShardedGradScaler
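

# Usage sketch (hedged; `model` is an FSDP-wrapped module and `optimizer` its optimizer).
# Each rank writes and reads its own shard, so every rank must call save()/load():
#
#   checkpointer = FSDPCheckpointer(model, save_dir="output", optimizer=optimizer)
#   checkpointer.save("model_final", iteration=iteration)  # -> output/model_final.rank_N.pth
#   if checkpointer.has_checkpoint():
#       checkpointer.load(checkpointer.get_checkpoint_file())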