import os
import json
from pathlib import Path

import torch
import safetensors
import safetensors.torch  # submodule must be imported explicitly for save_file
import wandb


def create_folder_if_necessary(path):
    # Create the parent directory of `path` (everything before the last "/") if missing.
    path = "/".join(path.split("/")[:-1])
    Path(path).mkdir(parents=True, exist_ok=True)


def safe_save(ckpt, path):
    # Keep the previous checkpoint as `<path>.bak` so an interrupted write
    # never leaves us without a usable file.
    try:
        os.remove(f"{path}.bak")
    except OSError:
        pass
    try:
        os.rename(path, f"{path}.bak")
    except OSError:
        pass
    # Dispatch on the file extension to pick the serialization format.
    if path.endswith(".pt") or path.endswith(".ckpt"):
        torch.save(ckpt, path)
    elif path.endswith(".json"):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(ckpt, f, indent=4)
    elif path.endswith(".safetensors"):
        safetensors.torch.save_file(ckpt, path)
    else:
        raise ValueError(f"File extension not supported: {path}")


def load_or_fail(path, wandb_run_id=None):
    # Load a checkpoint, returning None if the file does not exist.
    # If loading fails and a wandb run id is given, send an alert before re-raising.
    accepted_extensions = [".pt", ".ckpt", ".json", ".safetensors"]
    try:
        assert any(
            path.endswith(ext) for ext in accepted_extensions
        ), f"Automatic loading not supported for this extension: {path}"
        if not os.path.exists(path):
            checkpoint = None
        elif path.endswith(".pt") or path.endswith(".ckpt"):
            checkpoint = torch.load(path, map_location="cpu")
        elif path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as f:
                checkpoint = json.load(f)
        elif path.endswith(".safetensors"):
            checkpoint = {}
            with safetensors.safe_open(path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    checkpoint[key] = f.get_tensor(key)
        return checkpoint
    except Exception as e:
        if wandb_run_id is not None:
            wandb.alert(
                title=f"Corrupt checkpoint for run {wandb_run_id}",
                text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
            )
        raise e
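

# Illustrative usage sketch (not part of the original module): the checkpoint path
# and the tensor dict below are hypothetical, chosen only to show the round trip of
# create_folder_if_necessary -> safe_save -> load_or_fail for a .safetensors file.
if __name__ == "__main__":
    ckpt_path = "checkpoints/example/model.safetensors"  # hypothetical path
    create_folder_if_necessary(ckpt_path)

    # safetensors files hold a flat dict of tensors only.
    state = {"weight": torch.zeros(2, 2)}
    safe_save(state, ckpt_path)

    # Returns None if the file were missing; raises (and optionally alerts wandb) on failure.
    restored = load_or_fail(ckpt_path)
    print(None if restored is None else list(restored.keys()))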