dofa / extract.py
ajstewart's picture
Upload 2 files
267900f verified
raw
history blame
1.01 kB
#!/usr/bin/env python3
"""Extract the model backbone from the checkpoint."""
import torch
from torchgeo.models import dofa_base_patch16_224
# Input checkpoint and output path for the cleaned weights.
# The output should be manually renamed later: add the first 8 digits of its
# sha256 to the suffix.
in_filename = "ofa_base_checkpoint_e99.pth"
out_filename = "dofa_base_patch16_224.pth"

# Extra keys removed from the checkpoint before it is loaded into the model
extra_keys = (
    "mask_token",
    "norm.weight",
    "norm.bias",
    "projector.weight",
    "projector.bias",
)

# fc_norm and head are generated dynamically
allowed_missing_keys = {"fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}


def drop_keys(state_dict, keys) -> None:
    """Delete every key in *keys* from *state_dict*, in place.

    All keys are checked before anything is deleted, so a failure leaves
    *state_dict* untouched.

    Args:
        state_dict: Mutable mapping to remove entries from.
        keys: Keys that must be present and will be removed.

    Raises:
        KeyError: If any of *keys* is absent; the message lists all of them.
    """
    absent = [key for key in keys if key not in state_dict]
    if absent:
        raise KeyError(f"checkpoint is missing expected keys: {absent}")
    for key in keys:
        del state_dict[key]


def check_load_result(missing_keys, unexpected_keys, allowed_missing) -> None:
    """Verify the key report of a non-strict ``load_state_dict`` call.

    Explicit exceptions are used instead of ``assert`` so the validation is
    not stripped when Python runs with ``-O``.

    Args:
        missing_keys: Keys the model expected but the weights did not provide.
        unexpected_keys: Keys the weights provided but the model did not expect.
        allowed_missing: Keys that are permitted to appear in *missing_keys*.

    Raises:
        RuntimeError: If a missing key is not in *allowed_missing*, or if
            there is any unexpected key.
    """
    disallowed = sorted(set(missing_keys) - set(allowed_missing))
    if disallowed:
        raise RuntimeError(f"weights are missing required keys: {disallowed}")
    if unexpected_keys:
        raise RuntimeError(
            f"weights contain unexpected keys: {sorted(unexpected_keys)}"
        )


def main() -> None:
    """Load the checkpoint, strip the extra keys, validate, and save the result."""
    # Load the checkpoint
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only run this on a trusted checkpoint. Passing weights_only=True
    # would be safer, but it is unverified here whether this checkpoint holds
    # only tensors and plain containers -- TODO confirm before switching.
    checkpoint = torch.load(in_filename, map_location=torch.device("cpu"))

    # Remove extra keys
    weights = checkpoint["model"]
    drop_keys(weights, extra_keys)

    # Load the weights to ensure they are valid
    model = dofa_base_patch16_224()
    missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
    check_load_result(missing_keys, unexpected_keys, allowed_missing_keys)

    # Save the cleaned checkpoint
    torch.save(weights, out_filename)


if __name__ == "__main__":
    main()