File size: 5,024 Bytes
f21d996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d7757
 
f21d996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d7757
f21d996
 
 
 
c6d7757
 
 
 
 
 
 
f21d996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d7757
f21d996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6d7757
f21d996
 
 
c6d7757
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import argparse
from safetensors import safe_open
from safetensors.torch import save_file
import json
from tqdm import tqdm

def get_tensor_locations(input_dir, num_input_files=51):
    """Map every tensor name to the index of the input shard that stores it.

    Shards are expected to be named ``model-XXXXX-of-YYYYY.safetensors``
    with 1-based, zero-padded indices, where ``YYYYY`` is the total count.

    Args:
        input_dir: Directory containing the input safetensors shards.
        num_input_files: Total number of input shards. Defaults to 51, the
            shard count this script was originally written for.

    Returns:
        A dict mapping tensor key -> 1-based index of the shard holding it.
        If a key appears in several shards, the highest shard index wins.
    """
    tensor_locations = {}
    shard_indices = range(1, num_input_files + 1)
    for i in tqdm(shard_indices, desc="Scanning input files"):
        file_path = os.path.join(
            input_dir, f"model-{i:05d}-of-{num_input_files:05d}.safetensors"
        )
        # Only the keys are read here; tensor data is not loaded.
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                tensor_locations[key] = i
    return tensor_locations

def create_merge_plan(tensor_locations, layer_config):
    """Build the list of tensor copy operations for the layer-stacked model.

    Source layers are emitted in the order given by ``layer_config`` and
    renumbered consecutively from 0. Each emitted layer that has at least
    one tensor gets its own output file, numbered from 1.

    Args:
        tensor_locations: Dict mapping tensor key -> index of the input file
            that contains it (as returned by ``get_tensor_locations``).
        layer_config: Iterable of dicts, each with a ``'layer_range'`` entry
            ``[start, end]`` selecting source layers ``start`` .. ``end - 1``.
            Ranges may overlap, in which case a source layer is duplicated.

    Returns:
        A list of dicts with keys ``'old_key'``, ``'new_key'``,
        ``'original_file_index'`` and ``'new_file_index'``.
    """
    layer_prefix = "model.layers."

    # Special handling for specific weights: key -> destination file index.
    # NOTE(review): 156 is hard-coded and does not depend on how many layer
    # files this plan produces -- confirm it matches the intended layout.
    special_weights = {
        "model.embed_tokens.weight": 1,
        "lm_head.weight": 156,
        "model.norm.weight": 156,
    }

    # Group layer tensors by their source layer token once, instead of
    # rescanning every key for every emitted layer. Dict insertion order
    # keeps the per-layer tensor order identical to tensor_locations.
    tensors_by_layer = {}
    for key in tensor_locations:
        if not key.startswith(layer_prefix):
            continue
        layer_token, dot, suffix = key[len(layer_prefix):].partition(".")
        if dot:
            tensors_by_layer.setdefault(layer_token, []).append((key, suffix))

    merge_plan = []
    new_layer_idx = 0
    new_file_idx = 1

    for slice_config in layer_config:
        start, end = slice_config['layer_range']
        for i in range(start, end):
            # Only the leading layer index is rewritten; the rest of the key
            # is carried over unchanged.
            layer_tensors = [
                {
                    'old_key': key,
                    'new_key': f"{layer_prefix}{new_layer_idx}.{suffix}",
                    'original_file_index': tensor_locations[key],
                    'new_file_index': new_file_idx,
                }
                for key, suffix in tensors_by_layer.get(str(i), ())
            ]
            if layer_tensors:
                merge_plan.extend(layer_tensors)
                new_file_idx += 1
            # The new layer index advances even when the source layer has no
            # tensors, so a missing layer leaves a gap in the numbering.
            new_layer_idx += 1

    # Add special weights. The source file must be where the tensor actually
    # lives in the input, not the destination index; weights that the input
    # does not contain are skipped because there is nothing to copy.
    for key, new_file_index in special_weights.items():
        if key not in tensor_locations:
            continue
        merge_plan.append({
            'old_key': key,
            'new_key': key,
            'original_file_index': tensor_locations[key],
            'new_file_index': new_file_index,
        })

    # Add any remaining non-layer tensors to the first file
    for key, file_index in tensor_locations.items():
        if not key.startswith(layer_prefix) and key not in special_weights:
            merge_plan.append({
                'old_key': key,
                'new_key': key,
                'original_file_index': file_index,
                'new_file_index': 1,
            })

    return merge_plan

def merge_layers(input_dir, output_dir, merge_plan, start_file_index=1, num_input_files=51):
    """Execute a merge plan, writing one safetensors file per output index.

    Output files are named ``model-XXXXX-of-YYYYY.safetensors`` where
    ``YYYYY`` is the largest ``new_file_index`` in the plan. An output file
    that already exists is skipped, which lets an interrupted run resume.

    Args:
        input_dir: Directory containing the input safetensors shards.
        output_dir: Directory that receives the output shards (must exist).
        merge_plan: List of dicts with ``'old_key'``, ``'new_key'``,
            ``'original_file_index'`` and ``'new_file_index'``.
        start_file_index: First output file index to process.
        num_input_files: Total number of input shards, used to build the
            input file names. Defaults to 51.

    Returns:
        None.
    """
    if not merge_plan:
        # max() over an empty plan would raise; there is nothing to write.
        print("Merge plan is empty; nothing to merge.")
        return

    # Bucket the plan by destination file once instead of rescanning the
    # whole plan for every output file.
    items_by_output = {}
    for item in merge_plan:
        items_by_output.setdefault(item['new_file_index'], []).append(item)
    max_file_index = max(items_by_output)

    with tqdm(total=len(merge_plan), desc="Merging layers") as pbar:
        for file_index in range(start_file_index, max_file_index + 1):
            items = items_by_output.get(file_index, [])
            output_file = os.path.join(output_dir, f"model-{file_index:05d}-of-{max_file_index:05d}.safetensors")

            if os.path.exists(output_file):
                pbar.update(len(items))
                continue

            if not items:
                # No tensors are assigned to this index; no file is written.
                continue

            # Open each input shard once per output file rather than once
            # per tensor.
            items_by_input = {}
            for item in items:
                items_by_input.setdefault(item['original_file_index'], []).append(item)

            output_tensors = {}
            for original_file_index, group in items_by_input.items():
                input_file = os.path.join(input_dir, f"model-{original_file_index:05d}-of-{num_input_files:05d}.safetensors")
                with safe_open(input_file, framework="pt", device="cpu") as f:
                    for item in group:
                        output_tensors[item['new_key']] = f.get_tensor(item['old_key'])
                        pbar.update(1)

            # Write to a temporary name and rename afterwards, so that an
            # interrupted run never leaves a truncated shard under the final
            # name (which the existence check above would then skip).
            tmp_file = output_file + ".tmp"
            save_file(output_tensors, tmp_file)
            os.replace(tmp_file, output_file)

    print(f"Merged model saved to {output_dir}")

def main():
    """Parse command-line arguments and either print or execute the merge plan.

    With ``--dry-run`` the plan is printed and written to
    ``merge_plan_large.json`` in the current working directory; otherwise the
    output directory is created and the merge is executed, starting at the
    file index given by ``--continue-from``.
    """
    parser = argparse.ArgumentParser(description="Merge and split Mistral model")
    parser.add_argument("input_dir", help="Directory containing input safetensors files")
    parser.add_argument("output_dir", help="Directory for output safetensors files")
    parser.add_argument("--dry-run", action="store_true", help="Perform a dry run and output merge plan")
    parser.add_argument("--continue-from", type=int, default=1, help="Continue merging from this file index")
    args = parser.parse_args()

    # Overlapping source layer ranges; layers in an overlap are emitted once
    # per range that contains them.
    layer_config = [
        {'layer_range': [0, 20]},
        {'layer_range': [10, 30]},
        {'layer_range': [20, 40]},
        {'layer_range': [30, 50]},
        {'layer_range': [40, 60]},
        {'layer_range': [50, 70]},
        {'layer_range': [60, 80]},
        {'layer_range': [70, 87]}
    ]

    tensor_locations = get_tensor_locations(args.input_dir)
    merge_plan = create_merge_plan(tensor_locations, layer_config)

    if args.dry_run:
        # Single source of truth for the path, so the message below always
        # names the file that was actually written.
        plan_path = "merge_plan_large.json"
        print("Merge plan:")
        print(json.dumps(merge_plan, indent=2))
        with open(plan_path, "w", encoding="utf-8") as f:
            json.dump(merge_plan, f, indent=2)
        print(f"Merge plan saved to {plan_path}")
    else:
        os.makedirs(args.output_dir, exist_ok=True)
        # merge_layers prints the completion message itself, so it is not
        # repeated here.
        merge_layers(args.input_dir, args.output_dir, merge_plan, start_file_index=args.continue_from)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()