# File size: 1,073 Bytes
# 8acbfe7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import json
from typing import Dict

from safetensors.torch import load_file, save_file
from huggingface_hub import split_torch_state_dict_into_shards
import torch
import os

def save_state_dict(
    state_dict: Dict[str, torch.Tensor],
    save_directory: str,
    *,
    max_shard_size=None,
):
    """Save a state dict as one or more safetensors shards in ``save_directory``.

    Tensors are grouped into shards by ``split_torch_state_dict_into_shards``
    using the filename pattern ``consolidated{suffix}.safetensors``. When the
    split produces more than one shard, a ``consolidated.safetensors.index.json``
    file is also written. It holds the split metadata and a ``weight_map`` from
    each tensor name to the shard file that contains it.

    Args:
        state_dict: Mapping of tensor name to tensor to save.
        save_directory: Output directory. Created if it does not exist.
        max_shard_size: Optional shard size limit forwarded to
            ``split_torch_state_dict_into_shards`` (an int number of bytes or a
            string such as ``"5GB"``). ``None`` keeps the library default.
    """
    # Only forward max_shard_size when given, so the default call is exactly
    # what it was before this parameter existed.
    split_kwargs = {}
    if max_shard_size is not None:
        split_kwargs["max_shard_size"] = max_shard_size
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern="consolidated{suffix}.safetensors",
        **split_kwargs,
    )

    # Make sure the target directory exists before writing any shard into it.
    os.makedirs(save_directory, exist_ok=True)

    for filename, tensor_names in state_dict_split.filename_to_tensors.items():
        shard = {name: state_dict[name] for name in tensor_names}
        print("Saving", save_directory, filename)
        save_file(shard, os.path.join(save_directory, filename))

    # An index is only needed when the tensors are spread over several files.
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        index_path = os.path.join(save_directory, "consolidated.safetensors.index.json")
        with open(index_path, "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2)

if __name__ == "__main__":
    # Read the single large checkpoint from the current working directory.
    big_file = "consolidated.safetensors"
    loaded = load_file(big_file)

    # Write the shards (and the index, if sharded) into the current directory.
    # NOTE(review): save_state_dict names its output files with the pattern
    # "consolidated{suffix}.safetensors". If everything fits in one shard, the
    # suffix is presumably empty, so the output would be the input file
    # itself, rewritten in place. Confirm this is intended, or write to a
    # different directory.
    save_state_dict(loaded, save_directory=".")