# Scraped from the Hugging Face file viewer: last commit eccbce0 by AdamOswald1
# ("Upload 3 files"); original file size 380 bytes.
import argparse
import os

import torch
def pruned_path(path):
    """Return the default output path for a pruned copy of *path*.

    The ``pruned-`` prefix is applied to the file name only, so the result sits
    next to the input: ``models/a.ckpt`` -> ``models/pruned-a.ckpt``.
    Prefixing the whole path string would instead point into a ``pruned-models``
    directory that normally does not exist.
    """
    directory, name = os.path.split(path)
    return os.path.join(directory, f"pruned-{name}")


def prune_checkpoint(checkpoint, drop_keys=("optimizer_states",)):
    """Return a shallow copy of *checkpoint* without the top-level *drop_keys*.

    Args:
        checkpoint: Mapping loaded from a checkpoint file.
        drop_keys: Top-level keys to omit. The default, ``optimizer_states``,
            matches the key PyTorch Lightning uses for optimizer state;
            presumably that is where these checkpoints come from (confirm).

    Returns:
        A new dict holding the remaining entries in their original order. The
        values are shared with *checkpoint*, not copied.
    """
    return {key: value for key, value in checkpoint.items() if key not in drop_keys}


def main(argv=None):
    """Load the checkpoint named by ``--input``, drop its optimizer state, and save it.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The path the pruned checkpoint was written to.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', '-I', type=str, help='Input file to prune', required=True)
    parser.add_argument(
        '--output', '-O', type=str, default=None,
        help='Where to write the pruned checkpoint (default: pruned-<name> next to the input)',
    )
    args = parser.parse_args(argv)

    # map_location="cpu" lets checkpoints saved from a GPU be pruned on a
    # CPU-only machine; the pruned file therefore stores CPU tensors.
    # NOTE(review): depending on the PyTorch version, torch.load may fully
    # unpickle the file (weights_only=False), which can execute arbitrary code.
    # Only prune trusted checkpoints. Newer versions default to
    # weights_only=True, which may reject checkpoints that pickle arbitrary
    # Python objects; confirm against the installed version.
    checkpoint = torch.load(args.input, map_location="cpu")

    output = args.output or pruned_path(args.input)
    torch.save(prune_checkpoint(checkpoint), output)
    return output


if __name__ == "__main__":
    main()