|
|
|
|
|
"""Extract the model backbone from the checkpoint.""" |
|
|
|
import torch |
|
|
|
from torchgeo.models import dofa_base_patch16_224 |
|
|
|
|
|
in_filename = "ofa_base_checkpoint_e99.pth"
# NOTE(review): torch.load unpickles the file, so only run this on a trusted
# checkpoint. Depending on the installed torch version, `weights_only` may
# default to True; if loading fails because the checkpoint stores non-tensor
# objects, confirm what it contains before changing that flag.
checkpoint = torch.load(in_filename, map_location=torch.device("cpu"))

# The state dict to export is stored under the "model" key. Drop entries that
# the DOFA backbone does not use (presumably pretraining-only parameters).
# `del` raises KeyError if any of them is absent, which surfaces an unexpected
# checkpoint layout instead of silently continuing.
weights = checkpoint["model"]
for key in (
    "mask_token",
    "norm.weight",
    "norm.bias",
    "projector.weight",
    "projector.bias",
):
    del weights[key]

# Check that the stripped weights fit the DOFA base model. Only the fc_norm.*
# and head.* parameters may be missing from the checkpoint; any other missing
# key, or any key the model does not recognise, is an error.
allowed_missing_keys = {"fc_norm.weight", "fc_norm.bias", "head.weight", "head.bias"}
model = dofa_base_patch16_224()
missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)

# Explicit raises instead of `assert`, so the checks still run under `python -O`.
disallowed_missing_keys = set(missing_keys) - allowed_missing_keys
if disallowed_missing_keys:
    raise RuntimeError(
        f"Checkpoint is missing model keys: {sorted(disallowed_missing_keys)}"
    )
if unexpected_keys:
    raise RuntimeError(
        f"Checkpoint has keys the model does not expect: {sorted(unexpected_keys)}"
    )

# Save the stripped checkpoint weights themselves (not model.state_dict()),
# so the output contains no fc_norm/head entries.
out_filename = "dofa_base_patch16_224.pth"
torch.save(weights, out_filename)
|
|