Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
import logging
from pathlib import Path
import re
import typing as tp

import flashy
import torch

from ..environment import AudioCraftEnvironment
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class CheckpointSource(Enum):
    """Closed set of labels identifying where a checkpoint comes from.

    NOTE(review): the precise meaning of each member is defined by the callers,
    which are not visible from this module — the names suggest the current
    experiment, a pretrained model, or any other location; confirm at call sites.
    """
    CURRENT_XP = "current_xp"
    PRETRAINED = "pretrained"
    OTHER = "other"
def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
    """Checkpoint name formatted for all use in AudioCraft codebase and has the following format:
    `checkpoint(_<name>).th(.<rank>)`. By convention, name is expected to be omitted for the last
    checkpoint, 'best' for the best checkpoint or the epoch number.

    The `_<name>` part is only added when `name` is not None, and the `.<rank>` part is only
    added for non-zero ranks when `use_fsdp` is True.

    Args:
        name (str, optional): Name suffix for the checkpoint file stem.
        rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        str: The checkpoint name.
    """
    if rank is None:
        rank = flashy.distrib.rank()
    # Only the non-zero ranks of an FSDP run get a per-rank suffix; rank 0 keeps the plain name.
    suffix = f'.{rank}' if use_fsdp and rank > 0 else ''
    name_part = '' if name is None else f'_{name}'
    return f'checkpoint{name_part}.th{suffix}'
def is_sharded_checkpoint(path: Path) -> bool:
    """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank.

    Only the file name is inspected (the file is never opened): it is considered sharded
    when it ends with `.th.<digits>`, i.e. the per-rank suffix format.

    Args:
        path (Path): Path to the checkpoint file.
    Returns:
        bool: True if the file name ends with `.th.<digits>`.
    """
    return re.search(r'\.th\.\d+$', path.name) is not None
def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
                            use_fsdp: bool = False) -> tp.Optional[Path]:
    """Resolve a given checkpoint path for a provided dora sig or path.

    A value starting with `//sig/` is treated as a dora signature and mapped to
    `<dora dir>/xps/<sig>`; anything else is treated as a path and passed through
    `AudioCraftEnvironment.resolve_reference_path`. If the result is a directory,
    the checkpoint file name built by `checkpoint_name(name, use_fsdp=use_fsdp)`
    is appended to it.

    Args:
        sig_or_path (Path or str): Checkpoint path or dora signature.
        name (str, optional): Name suffix for the checkpoint file stem.
        use_fsdp (bool): Whether the calling solver relies on FSDP.
    Returns:
        Path, optional: Resolved checkpoint path, if it exists.
    """
    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import — TODO confirm).
    from audiocraft import train
    xps_root = train.main.dora.dir / 'xps'
    sig_or_path = str(sig_or_path)
    if sig_or_path.startswith('//sig/'):
        sig = sig_or_path[len('//sig/'):]
        path = xps_root / sig
    else:
        path = Path(sig_or_path)
        path = AudioCraftEnvironment.resolve_reference_path(path)

    if path.is_dir():
        path = path / checkpoint_name(name, use_fsdp=use_fsdp)

    return path if path.exists() else None
def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
    """Load state from checkpoints at the specified checkpoint path.

    Args:
        checkpoint_path (Path): Path of the checkpoint file to load.
        is_sharded (bool): Whether the checkpoint is sharded across ranks. When True and the
            rank zero checkpoint (`checkpoint.th` in the same directory) exists, the sharded
            checkpoint state is checked with `check_sharded_checkpoint` before loading.
    Returns:
        Any: The state stored in the checkpoint, with tensors mapped to CPU.
    """
    if is_sharded:
        rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
        if rank0_checkpoint_path.exists():
            check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
    # NOTE(review): torch.load relies on pickle; only load checkpoints from trusted sources.
    state = torch.load(checkpoint_path, map_location='cpu')
    logger.info("Checkpoint loaded from %s", checkpoint_path)
    return state
def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save state to disk to the specified checkpoint_path.

    The actual write is delegated to `_safe_save_checkpoint`; a log line is emitted
    once it returns.

    Args:
        state (Any): State to serialize.
        checkpoint_path (Path): Destination path of the checkpoint file.
        is_sharded (bool): Whether the checkpoint is sharded across ranks.
    """
    _safe_save_checkpoint(state, checkpoint_path, is_sharded)
    logger.info("Checkpoint saved to %s", checkpoint_path)
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
    """Flush checkpoints to only keep last N checkpoints.

    Looks in the parent directory of `checkpoint_path` for files named
    `checkpoint_<epoch>.th` (with a `.<rank>` suffix on non-zero ranks) where `<epoch>`
    is a decimal number, and removes all but the `keep_last` highest epochs.
    Checkpoints whose name part is not a number (e.g. 'best') are never removed.

    Args:
        checkpoint_path (Path): Path of a checkpoint; only its parent directory is used.
        keep_last (int, optional): Number of epoch checkpoints to keep. Nothing is
            removed when None or not strictly positive.
    """
    if keep_last is None or keep_last <= 0:
        return
    checkpoint_dir = checkpoint_path.parent
    rank = flashy.distrib.rank()
    # Each non-zero rank only looks at (and removes) the files carrying its own rank suffix.
    suffix = f'.{rank}' if rank > 0 else ''
    checkpoint_files_with_epoch = []
    for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'):
        # 'checkpoint_<epoch>.th...' -> '<epoch>'
        epoch_part = path.name.split('.', 1)[0].split('_', 1)[1]
        # isdecimal() rather than isdigit(): the latter accepts characters such as
        # superscript digits that int() cannot parse.
        if epoch_part.isdecimal():
            checkpoint_files_with_epoch.append((path, int(epoch_part)))
    checkpoint_files = [path for path, _ in sorted(checkpoint_files_with_epoch, key=lambda t: t[1])]
    # Oldest epochs come first; max() guards against a negative slice bound
    # when there are fewer files than keep_last.
    total_to_flush = max(0, len(checkpoint_files) - keep_last)
    for path in checkpoint_files[:total_to_flush]:
        logger.debug("Removing checkpoint: %s", path)
        path.unlink(missing_ok=True)
def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
    """Check sharded checkpoint state, ensuring the checkpoints are not corrupted.

    If the `<rank0 checkpoint>.tmp.done` token is present and this rank still has a
    `<checkpoint>.tmp` file, the temporary file is renamed to the final checkpoint path.
    All ranks then synchronize and rank zero removes the token.

    Args:
        checkpoint_path (Path): Path of the checkpoint file for the current rank.
        rank0_checkpoint_path (Path): Path of the checkpoint file for rank zero.
    Raises:
        RuntimeError: If a `<checkpoint>.old` file exists for the current rank.
    """
    # Finish the work of a previous run that got interrupted while dumping.
    old_path = Path(str(checkpoint_path) + '.old')
    if old_path.exists():
        raise RuntimeError(
            f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.")
    token = Path(str(rank0_checkpoint_path) + '.tmp.done')
    tmp_path = Path(str(checkpoint_path) + '.tmp')
    if token.exists() and tmp_path.exists():
        tmp_path.rename(checkpoint_path)
    # The barrier is unconditional: every rank must reach it before rank zero removes the
    # token, otherwise a slower rank could miss the token and skip its rename.
    flashy.distrib.barrier()
    if flashy.distrib.is_rank_zero() and token.exists():
        token.unlink()
def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
    """Save a checkpoint in a safe manner, even when checkpoints are sharded across ranks.

    A `<checkpoint_path>.tmp.done` token file, managed by rank zero only, marks the window
    in which the state has been written but `flashy.utils.write_and_rename` has not yet
    exited; `check_sharded_checkpoint` looks for this token when loading.

    Args:
        state (Any): State to serialize with `torch.save`.
        checkpoint_path (Path): Destination path of the checkpoint file.
        is_sharded (bool): Whether the checkpoint is sharded across ranks, in which case
            ranks are synchronized with barriers between each step.
    """
    def _barrier_if_sharded():
        if is_sharded:
            flashy.distrib.barrier()

    # Bound on every rank so the name is always defined; only rank zero touches the file.
    token = Path(str(checkpoint_path) + '.tmp.done')
    if flashy.distrib.is_rank_zero():
        # Clear any token left over from a previous save.
        if token.exists():
            token.unlink()
    _barrier_if_sharded()
    # NOTE(review): write_and_rename presumably writes to a temporary file and renames it to
    # checkpoint_path when the context exits — confirm against flashy.
    with flashy.utils.write_and_rename(checkpoint_path) as f:
        torch.save(state, f)
        # Wait for every rank to finish writing before rank zero marks the dump as complete.
        _barrier_if_sharded()
        if flashy.distrib.is_rank_zero():
            token.touch()
        # Make sure the token exists before any rank leaves the context manager.
        _barrier_if_sharded()
    # Every rank has left the context manager; the token is no longer needed.
    _barrier_if_sharded()
    if flashy.distrib.is_rank_zero():
        token.unlink()