File size: 1,942 Bytes
5231633 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import os
import torch
import json
from pathlib import Path
import safetensors
import wandb
def create_folder_if_necessary(path):
    """Create the directory that will contain the file at ``path``.

    Only the parent directory (and any missing ancestors) is created; the
    file itself is never touched. Calling this repeatedly is harmless.

    Args:
        path: Path of a file, as ``str`` or ``os.PathLike``. Everything
            before the last path separator is treated as the directory.
            A path with a trailing separator is therefore treated as the
            directory itself.
    """
    # os.path.dirname understands the platform's separators and accepts
    # path-like objects, unlike splitting the string on a literal "/".
    parent = os.path.dirname(os.fspath(path))
    if parent:
        # An empty parent means "current directory": nothing to create.
        Path(parent).mkdir(parents=True, exist_ok=True)
def safe_save(ckpt, path):
    """Write ``ckpt`` to ``path``, keeping the previous file as ``<path>.bak``.

    The serializer is selected from the file extension:

    * ``.pt`` / ``.ckpt``  -> ``torch.save``
    * ``.json``            -> ``json.dump`` (UTF-8, ``indent=4``)
    * ``.safetensors``     -> ``safetensors.torch.save_file``

    Before writing, any existing ``<path>.bak`` is removed and the current
    file at ``path`` (if there is one) is renamed to ``<path>.bak``.

    Args:
        ckpt: Object to serialize; it must be acceptable to the serializer
            selected by the extension.
        path: Destination file path (``str`` or ``os.PathLike``).

    Raises:
        ValueError: If the extension is not supported. The check runs before
            any file is touched, so an unsupported path leaves the existing
            file and its backup exactly as they were.
    """
    path = os.fspath(path)
    if not path.endswith((".pt", ".ckpt", ".json", ".safetensors")):
        raise ValueError(f"File extension not supported: {path}")

    backup_path = f"{path}.bak"
    # Remove the old backup first: os.rename does not overwrite an existing
    # destination on every platform.
    try:
        os.remove(backup_path)
    except OSError:
        pass  # best effort: no previous backup to remove
    try:
        os.rename(path, backup_path)
    except OSError:
        pass  # best effort: nothing to back up yet (e.g. first save)

    if path.endswith((".pt", ".ckpt")):
        torch.save(ckpt, path)
    elif path.endswith(".json"):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(ckpt, f, indent=4)
    else:
        # Import the submodule explicitly: a bare ``import safetensors``
        # does not guarantee that ``safetensors.torch`` is loaded.
        from safetensors.torch import save_file

        save_file(ckpt, path)
def load_or_fail(path, wandb_run_id=None):
    """Load a checkpoint from ``path``, choosing the loader by extension.

    Args:
        path: File to load (``str`` or ``os.PathLike``). Must end in
            ``.pt``, ``.ckpt``, ``.json`` or ``.safetensors``.
        wandb_run_id: Optional run identifier. When it is not ``None``, any
            failure sends a ``wandb.alert`` before the error is re-raised.

    Returns:
        ``None`` when ``path`` does not exist. Otherwise the loaded object:
        the result of ``torch.load`` mapped to CPU (``.pt`` / ``.ckpt``),
        the parsed JSON document (``.json``), or a dict mapping each key in
        the file to its tensor, loaded on CPU (``.safetensors``).

    Raises:
        AssertionError: If the extension is not supported.
        Exception: Any error raised by the underlying loader is re-raised
            unchanged after the optional alert.
    """
    accepted_extensions = (".pt", ".ckpt", ".json", ".safetensors")
    try:
        path = os.fspath(path)
        if not path.endswith(accepted_extensions):
            # Raised explicitly instead of via ``assert`` so the check still
            # runs under ``python -O``; the exception type is kept as
            # AssertionError for backward compatibility.
            raise AssertionError(
                f"Automatic loading not supported for this extension: {path}"
            )
        if not os.path.exists(path):
            return None
        if path.endswith((".pt", ".ckpt")):
            # NOTE(security): torch.load may unpickle arbitrary objects
            # (depending on the installed torch version's defaults); only
            # load checkpoints from trusted sources.
            return torch.load(path, map_location="cpu")
        if path.endswith(".json"):
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        # Remaining accepted extension: .safetensors
        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
            return {key: f.get_tensor(key) for key in f.keys()}
    except Exception:
        if wandb_run_id is not None:
            wandb.alert(
                title=f"Corrupt checkpoint for run {wandb_run_id}",
                text=f"Training {wandb_run_id} tried to load checkpoint {path} and failed",
            )
        # Bare raise keeps the original traceback intact.
        raise
|