File size: 1,020 Bytes
cff1674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import argparse
import torch
from tqdm import tqdm

# Command-line tool: load a checkpoint, keep only the "state_dict" entries
# whose key contains "model", optionally cast floating-point weights to
# float16, and save the result as {"state_dict": ...}.


def keep_key(key):
    """Return True if a state-dict entry named *key* survives pruning.

    This is a plain substring test: any key containing "model" anywhere is
    kept (so e.g. a "model_ema.*" key would be kept as well as "model.*").
    NOTE(review): confirm this is the intended filter for your checkpoints.
    """
    return "model" in key


def halve_value(value):
    """Return *value* cast to float16 if it is a floating-point tensor.

    Integer/bool tensors (counters, index buffers) and non-tensor values are
    returned unchanged: casting them to float16 would corrupt them (float16
    is only exact for integers up to 2048 and overflows to inf above 65504),
    and non-tensors have no ``.half()`` method at all.
    """
    if isinstance(value, torch.Tensor) and value.is_floating_point():
        return value.half()
    return value


def main(argv=None):
    """Run the pruning CLI.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Prune a model")
    parser.add_argument("model_prune", type=str, help="Path to model to prune")
    parser.add_argument("prune_output", type=str, help="Path to pruned ckpt output")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    args = parser.parse_args(argv)

    print("Loading model...")
    # Load onto CPU so a GPU-saved checkpoint works on CPU-only machines and
    # the saved output is device-independent.
    # SECURITY NOTE: depending on the torch version, torch.load may unpickle
    # arbitrary Python objects -- only load checkpoints from trusted sources.
    model_prune = torch.load(args.model_prune, map_location="cpu")
    theta_prune = model_prune["state_dict"]

    print("Pruning model...")
    theta = {
        key: theta_prune[key]
        for key in tqdm(theta_prune, desc="Pruning keys")
        if keep_key(key)
    }
    # Drop BOTH references: deleting only theta_prune would leave the full
    # checkpoint alive through model_prune.
    del model_prune, theta_prune

    if args.half:
        print("Halving model...")
        state_dict = {
            k: halve_value(v)
            for k, v in tqdm(theta.items(), desc="Halving weights")
        }
    else:
        state_dict = theta

    del theta

    print("Saving pruned model...")

    torch.save({"state_dict": state_dict}, args.prune_output)

    del state_dict

    print("Done pruning!")


if __name__ == "__main__":
    main()