File size: 972 Bytes
267900f 4c557dd 267900f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
#!/usr/bin/env python3
"""Extract the model backbone from the checkpoint."""
import hashlib

import torch
from torchgeo.models import dofa_base_patch16_224
# Load the checkpoint onto the CPU.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# The strict key check below requires every entry to be a model parameter, so
# the checkpoint is expected to be a plain state dict of tensors.
in_filename = "DOFA_ViT_base_e100.pth"
weights = torch.load(
    in_filename, map_location=torch.device("cpu"), weights_only=True
)
# Remove extra keys that the torchgeo backbone does not define (presumably
# pretraining-only tokens/heads). Left in place, they would show up as
# unexpected keys when the weights are loaded into the model below.
# `del` raises KeyError if a key is absent, which flags an unexpected format.
del weights["mask_token"]
del weights["norm.weight"], weights["norm.bias"]
del weights["projector.weight"], weights["projector.bias"]
# Load the weights into a fresh model to ensure they are valid.
# fc_norm and head are generated dynamically, so they may legitimately be
# absent from the checkpoint; any other missing key is an error.
allowed_missing_keys = {"fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
model = dofa_base_patch16_224()
missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
# Explicit checks instead of `assert`: asserts are stripped under `python -O`,
# and these messages name the offending keys.
extra_missing_keys = set(missing_keys) - allowed_missing_keys
if extra_missing_keys:
    raise RuntimeError(
        f"Checkpoint is missing keys: {sorted(extra_missing_keys)}"
    )
if unexpected_keys:
    raise RuntimeError(f"Checkpoint has unexpected keys: {unexpected_keys}")
# Save the cleaned checkpoint.
out_filename = "dofa_base_patch16_224.pth"
torch.save(weights, out_filename)
# The file should be renamed manually later by appending the first 8 hex
# digits of its SHA-256 to the stem (torch.hub's hash-prefix convention).
# Compute the digest here, reading in 1 MiB chunks to bound memory use, and
# print the suggested name so it does not have to be derived by hand.
hasher = hashlib.sha256()
with open(out_filename, "rb") as checkpoint_file:
    for chunk in iter(lambda: checkpoint_file.read(1 << 20), b""):
        hasher.update(chunk)
stem, _, ext = out_filename.rpartition(".")
print(f"Suggested filename: {stem}-{hasher.hexdigest()[:8]}.{ext}")
|