minchul's picture
Upload directory
2fb0a42 verified
raw
history blame
3.12 kB
import itertools
from typing import List, Optional, Tuple, Union
import safetensors
import torch
from torch import Tensor
import os
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the first tensor found on a module.

    Registered parameters are consulted first, then registered buffers.
    When the module has neither, plain tensor attributes stored in the
    ``__dict__`` of the module (or of its submodules) are used instead.

    Args:
        parameter: The module to inspect.

    Returns:
        The ``device`` of the first tensor found.

    Raises:
        StopIteration: If no tensor of any kind is found.
    """
    registered = itertools.chain(parameter.parameters(), parameter.buffers())
    first_registered = next(registered, None)
    if first_registered is not None:
        return first_registered.device

    # For torch.nn.DataParallel compatibility in PyTorch 1.5
    def _tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [
            (attr_name, attr_value)
            for attr_name, attr_value in module.__dict__.items()
            if torch.is_tensor(attr_value)
        ]

    attribute_iter = parameter._named_members(get_members_fn=_tensor_attributes)
    _, first_tensor = next(attribute_iter)
    return first_tensor.device
def get_parameter_dtype(parameter: torch.nn.Module) -> Optional[torch.dtype]:
    """Return the dtype of the first tensor found on a module.

    Registered parameters are consulted first, then registered buffers,
    and finally plain tensor attributes stored in the ``__dict__`` of the
    module (or of its submodules).

    Args:
        parameter: The module to inspect.

    Returns:
        The ``dtype`` of the first tensor found, or ``None`` when the
        module holds no tensor at all.
    """
    registered = itertools.chain(parameter.parameters(), parameter.buffers())
    first_registered = next(registered, None)
    if first_registered is not None:
        return first_registered.dtype

    # For torch.nn.DataParallel compatibility in PyTorch 1.5
    # This fallback used to live in an ``except StopIteration`` branch that
    # could never run, because building tuples from the iterators does not
    # raise StopIteration. It is now reached whenever nothing is registered.
    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    first_tuple = next(gen, None)
    if first_tuple is None:
        # Preserve the historical result for modules without any tensors.
        return None
    return first_tuple[1].dtype
def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
    """Return the directory that contains ``save_path`` as a ``Path``."""
    return Path(save_path).parent
def get_base_name(save_path: Union[str, os.PathLike]) -> str:
    """Return the final component of ``save_path`` (the file or directory name)."""
    return Path(save_path).name
def load_state_dict_from_path(path: Union[str, os.PathLike]):
    """Load a state dict from a path.

    Paths containing the substring ``'safetensors'`` are read with
    ``safetensors.torch.load_file``; every other path is read with
    ``torch.load`` onto the CPU.

    Args:
        path: Location of the serialized state dict, as a string or any
            ``os.PathLike`` object.

    Returns:
        The loaded state dict.
    """
    # Normalize to text first: the ``in`` test on a Path object raises
    # TypeError, even though path-like inputs are part of the signature.
    path_text = os.fsdecode(path)
    if 'safetensors' in path_text:
        # Import the submodule explicitly; a bare ``import safetensors`` does
        # not guarantee that ``safetensors.torch`` is available as an attribute.
        from safetensors.torch import load_file

        state_dict = load_file(path)
    else:
        # NOTE(review): torch.load unpickles data; only use with trusted files.
        state_dict = torch.load(path, map_location="cpu")
    return state_dict
def replace_extension(path, new_extension):
    """Return ``path`` with its final extension swapped for ``new_extension``.

    ``new_extension`` may be given with or without a leading dot. Only the
    last extension of ``path`` is replaced; a path without an extension
    simply gets the new one appended.
    """
    suffix = new_extension if new_extension.startswith('.') else '.' + new_extension
    stem, _old_suffix = os.path.splitext(path)
    return stem + suffix
def make_config_path(save_path):
    """Return the config path for ``save_path``: the same path with a ``.yaml`` extension."""
    return replace_extension(save_path, '.yaml')
def save_config(config, config_path):
    """Write ``config`` to ``config_path`` using ``OmegaConf.save``.

    Args:
        config: A plain ``dict`` or a ``DictConfig``. A plain dict is
            converted with ``OmegaConf.create`` before saving.
        config_path: Destination file; its parent directory is created
            if it does not exist yet.
    """
    assert isinstance(config, (dict, DictConfig))
    target_dir = get_parent_directory(config_path)
    os.makedirs(target_dir, exist_ok=True)
    to_write = OmegaConf.create(config) if isinstance(config, dict) else config
    OmegaConf.save(to_write, config_path)
def save_state_dict_and_config(state_dict, config, save_path):
    """Save a state dict together with its config.

    The config is written next to the weights, at ``save_path`` with its
    extension replaced by ``.yaml``. The weights are written with
    ``safetensors.torch.save_file`` when ``save_path`` contains the
    substring ``'safetensors'``, and with ``torch.save`` otherwise.

    Args:
        state_dict: The state dict to serialize.
        config: A plain ``dict`` or a ``DictConfig`` describing the model.
        save_path: Destination of the weights file, as a string or any
            ``os.PathLike`` object. Missing parent directories are created.
    """
    os.makedirs(get_parent_directory(save_path), exist_ok=True)

    # save config dict
    config_path = make_config_path(save_path)
    save_config(config, config_path)

    # Save the model
    # Normalize to text first: the ``in`` test on a Path object raises TypeError.
    if 'safetensors' in os.fsdecode(save_path):
        # Import the submodule explicitly; a bare ``import safetensors`` does
        # not guarantee that ``safetensors.torch`` is available as an attribute.
        from safetensors.torch import save_file

        save_file(state_dict, save_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, save_path)