File size: 1,010 Bytes
267900f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#!/usr/bin/env python3

"""Extract the model backbone from the checkpoint."""

import torch

from torchgeo.models import dofa_base_patch16_224

# Load the checkpoint onto the CPU so no GPU is needed to run this script.
# NOTE(review): torch.load unpickles the file, so only run this on a trusted
# checkpoint. Newer PyTorch releases may default to weights_only=True, which
# can reject non-tensor entries stored in a training checkpoint -- confirm
# against the installed torch version.
in_filename = "ofa_base_checkpoint_e99.pth"
checkpoint = torch.load(in_filename, map_location=torch.device("cpu"))

# Keep only the model state dict and remove the extra keys that the backbone
# model would otherwise report as unexpected. ``del`` raises KeyError if any
# key is absent, which flags an unexpected checkpoint layout.
weights = checkpoint["model"]
for key in (
    "mask_token",
    "norm.weight",
    "norm.bias",
    "projector.weight",
    "projector.bias",
):
    del weights[key]

# Load the weights into a fresh model to ensure they are valid.
# fc_norm and head are not stored in the checkpoint; the model creates them.
allowed_missing_keys = {"fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
model = dofa_base_patch16_224()
missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)

# Explicit checks instead of ``assert`` so validation still runs under
# ``python -O``, and so the error names the offending keys.
disallowed_missing = set(missing_keys) - allowed_missing_keys
if disallowed_missing:
    raise RuntimeError(
        f"Checkpoint is missing model keys: {sorted(disallowed_missing)}"
    )
if unexpected_keys:
    raise RuntimeError(
        f"Checkpoint has unexpected keys: {sorted(unexpected_keys)}"
    )

# Save the cleaned state dict.
# Should be manually renamed later: append the first 8 hex digits of the
# file's sha256 as a suffix.
out_filename = "dofa_base_patch16_224.pth"
torch.save(weights, out_filename)