|
import itertools |
|
from typing import List, Optional, Tuple, Union |
|
import safetensors |
|
import torch |
|
from torch import Tensor |
|
import os |
|
from pathlib import Path |
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
|
|
def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the first tensor found on ``parameter``.

    Registered parameters are checked first, then registered buffers. If the
    module has neither, tensors stored as plain instance attributes (on the
    module or any of its submodules) are used instead.

    Raises:
        StopIteration: if no tensor can be found at all.
    """
    registered = itertools.chain(parameter.parameters(), parameter.buffers())
    first_registered = next(registered, None)
    if first_registered is not None:
        return first_registered.device

    # Nothing registered: look for tensors kept directly in ``__dict__``.
    # NOTE(review): presumably needed for wrappers/replicas that hold weights
    # as plain attributes -- confirm against the callers.
    def _plain_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(name, value) for name, value in vars(module).items() if torch.is_tensor(value)]

    members = parameter._named_members(get_members_fn=_plain_tensor_attributes)
    _, tensor = next(members)
    return tensor.device
|
|
|
|
|
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first tensor found on ``parameter``.

    Lookup order:
      1. the first registered parameter,
      2. the first registered buffer,
      3. the first tensor stored as a plain instance attribute on the module
         or any of its submodules.

    Returns:
        The ``torch.dtype`` of the first tensor found, or ``None`` when the
        module holds no tensors at all.
    """
    # The previous implementation wrapped tuple(...) calls in
    # ``try/except StopIteration``; those calls never raise StopIteration, so
    # step 3 was unreachable. Use an explicit sentinel so the fallback runs.
    registered = itertools.chain(parameter.parameters(), parameter.buffers())
    first_registered = next(registered, None)
    if first_registered is not None:
        return first_registered.dtype

    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        # Tensors assigned as ordinary attributes live in ``__dict__`` rather
        # than in the parameter/buffer registries.
        return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    first_tuple = next(gen, None)
    if first_tuple is None:
        # Keep the historical result for tensor-less modules instead of raising.
        return None
    return first_tuple[1].dtype
|
|
|
|
|
def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
    """Return the directory portion of ``save_path`` as a ``Path``."""
    return Path(save_path).parent
|
|
|
def get_base_name(save_path: Union[str, os.PathLike]) -> str:
    """Return the final path component of ``save_path`` (file name with extension)."""
    return Path(save_path).name
|
|
|
def load_state_dict_from_path(path: Union[str, os.PathLike]):
    """Load a state dict from ``path``.

    If the string form of ``path`` contains ``'safetensors'`` the file is read
    with ``safetensors.torch.load_file``; otherwise ``torch.load`` is used with
    tensors mapped to CPU.

    Args:
        path: location of the checkpoint, as a string or path-like object.

    Returns:
        The loaded state dict.
    """
    # Path-like objects do not support ``in``; test against the string form.
    # The match is a substring test on the whole path (not just the suffix),
    # which is kept for backward compatibility.
    if 'safetensors' in os.fspath(path):
        # Import the submodule explicitly: a bare ``import safetensors`` is
        # not guaranteed to make ``safetensors.torch`` available as an attribute.
        from safetensors.torch import load_file

        state_dict = load_file(path)
    else:
        # NOTE: torch.load relies on pickle; only use it on trusted files.
        state_dict = torch.load(path, map_location="cpu")
    return state_dict
|
|
|
def replace_extension(path, new_extension):
    """Return ``path`` with its extension replaced by ``new_extension``.

    ``new_extension`` may be given with or without the leading dot. Only the
    last extension is replaced; if ``path`` has none, the new one is appended.
    """
    suffix = new_extension if new_extension.startswith('.') else '.' + new_extension
    stem, _old_extension = os.path.splitext(path)
    return stem + suffix
|
|
|
def make_config_path(save_path):
    """Return the path of the config file for ``save_path``: same path, ``.yaml`` extension."""
    return replace_extension(save_path, '.yaml')
|
|
|
def save_config(config, config_path):
    """Write ``config`` to ``config_path`` via ``OmegaConf.save``.

    ``config`` must be a plain ``dict`` or a ``DictConfig``. The parent
    directory of ``config_path`` is created if it does not exist.
    """
    assert isinstance(config, (dict, DictConfig))

    target_dir = get_parent_directory(config_path)
    os.makedirs(target_dir, exist_ok=True)

    # Plain dicts are converted to an OmegaConf config before saving.
    conf = OmegaConf.create(config) if isinstance(config, dict) else config
    OmegaConf.save(conf, config_path)
|
|
|
|
|
def save_state_dict_and_config(state_dict, config, save_path):
    """Save ``state_dict`` to ``save_path`` and ``config`` next to it as YAML.

    The config is written to the path returned by ``make_config_path`` for
    ``save_path``. The weights are written with ``safetensors.torch.save_file``
    when the string form of ``save_path`` contains ``'safetensors'``, and with
    ``torch.save`` otherwise.

    Args:
        state_dict: the state dict to serialize.
        config: a ``dict`` or ``DictConfig`` accepted by ``save_config``.
        save_path: destination of the weights file (string or path-like).
    """
    os.makedirs(get_parent_directory(save_path), exist_ok=True)

    # Write the config alongside the weights.
    config_path = make_config_path(save_path)
    save_config(config, config_path)

    # Path-like objects do not support ``in``; test against the string form.
    # The match is a substring test on the whole path, kept for compatibility.
    if 'safetensors' in os.fspath(save_path):
        # Import the submodule explicitly: a bare ``import safetensors`` is
        # not guaranteed to make ``safetensors.torch`` available as an attribute.
        from safetensors.torch import save_file

        save_file(state_dict, save_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, save_path)
|
|