llmixer's picture
Model upload
8acbfe7 verified
raw
history blame
1.07 kB
import json
from typing import Dict
from safetensors.torch import load_file, save_file
from huggingface_hub import split_torch_state_dict_into_shards
import torch
import os
def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str) -> None:
    """Write *state_dict* to *save_directory* as one or more safetensors files.

    The shard layout is decided by ``split_torch_state_dict_into_shards``,
    with files named after the pattern ``consolidated{suffix}.safetensors``.
    When the split is reported as sharded, a
    ``consolidated.safetensors.index.json`` file holding the split metadata
    and the tensor-name -> filename map is written next to the shards.

    Args:
        state_dict: Mapping from tensor name to tensor.
        save_directory: Directory that receives the files. It is created
            (including missing parents) when it does not exist yet.
    """
    # An empty string addresses the current directory through os.path.join,
    # but os.makedirs("") raises, so only create a directory that is named.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    state_dict_split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern='consolidated{suffix}.safetensors',
    )

    for filename, tensors in state_dict_split.filename_to_tensors.items():
        # Collect only the tensors assigned to this shard file.
        shard = {tensor: state_dict[tensor] for tensor in tensors}
        print("Saving", save_directory, filename)
        save_file(shard, os.path.join(save_directory, filename))

    # A single-file result needs no index.
    if not state_dict_split.is_sharded:
        return

    index = {
        "metadata": state_dict_split.metadata,
        "weight_map": state_dict_split.tensor_to_filename,
    }
    index_path = os.path.join(save_directory, "consolidated.safetensors.index.json")
    with open(index_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(index, indent=2))
# Script entry: re-save the single-file checkpoint found in the current
# working directory, letting save_state_dict decide the shard layout.
big_file = "consolidated.safetensors"

# Load the whole checkpoint before anything is written back.
loaded = load_file(big_file)

# NOTE(review): the output goes to the same directory as the input; if the
# result is not sharded, the output name presumably equals `big_file` and
# would overwrite it -- confirm this is intended.
save_state_dict(loaded, ".")