|
|
|
|
|
"""Convert Satlas-Pretrain model checkpoints to a format accepted by TorchGeo. |
|
|
|
Reference implementation: |
|
|
|
* https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py |
|
""" |
|
|
|
import glob |
|
import hashlib |
|
import os |
|
|
|
import timm |
|
import torch |
|
import torchvision |
|
|
|
|
|
# (probe key, prefix) pairs describing how Satlas checkpoints nest the
# backbone weights. They are probed in order and the first probe key found
# in the state dict selects the prefix to strip.
_BACKBONE_PREFIXES = (
    ('backbone.backbone.resnet.conv1.weight', 'backbone.backbone.resnet.'),
    ('backbone.resnet.conv1.weight', 'backbone.resnet.'),
    ('backbone.backbone.backbone.features.0.0.weight', 'backbone.backbone.backbone.'),
    ('backbone.backbone.features.0.0.weight', 'backbone.backbone.'),
)


def strip_backbone_prefix(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return only the backbone weights of *state_dict*, with their wrapper prefix removed.

    The first probe key from ``_BACKBONE_PREFIXES`` that is present selects the
    prefix. Keys not starting with that prefix (e.g. heads or FPN layers) are
    dropped. Only the leading prefix is removed, never later occurrences.

    Args:
        state_dict: Raw state dict loaded from a Satlas checkpoint.

    Returns:
        A new dict with stripped keys. If no probe key matches, *state_dict*
        itself is returned unchanged.
    """
    for probe_key, prefix in _BACKBONE_PREFIXES:
        if probe_key in state_dict:
            return {
                key.removeprefix(prefix): value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
    return state_dict


def build_model(checkpoint: str, state_dict: dict[str, torch.Tensor]) -> torch.nn.Module:
    """Instantiate an untrained model whose architecture matches *checkpoint*.

    The architecture is inferred from the file name. The number of input
    channels is read from the first convolution in *state_dict*.

    Args:
        checkpoint: Checkpoint file name.
        state_dict: Backbone state dict, already passed through
            ``strip_backbone_prefix``.

    Returns:
        A timm ResNet or a torchvision Swin V2 model, ready for
        ``load_state_dict``.

    Raises:
        ValueError: If the architecture cannot be determined from the file name.
    """
    if 'resnet' in checkpoint:
        in_chans = state_dict['conv1.weight'].shape[1]
        # Assumes names shaped like '<prefix>_<timm model name>_...'; the
        # second '_'-separated field is passed to timm unchanged.
        model_name = checkpoint.split('_')[1]
        return timm.create_model(model_name, in_chans=in_chans)

    if 'swin' in checkpoint:
        out_channels, num_channels, kernel_h, kernel_w = state_dict['features.0.0.weight'].shape
        if 'swint' in checkpoint:
            model = torchvision.models.swin_v2_t()
        elif 'swinb' in checkpoint:
            model = torchvision.models.swin_v2_b()
        else:
            raise ValueError(f'Unknown Swin variant in checkpoint name: {checkpoint}')
        # Replace the patch-embedding conv so its input channels match the
        # checkpoint. The stride stays fixed at 4 as in the original conversion;
        # presumably this is torchvision's Swin V2 patch size (confirm).
        model.features[0][0] = torch.nn.Conv2d(
            num_channels, out_channels, kernel_size=(kernel_h, kernel_w), stride=(4, 4)
        )
        return model

    raise ValueError(f'Unknown architecture in checkpoint name: {checkpoint}')


def convert_checkpoint(checkpoint: str) -> None:
    """Convert one Satlas checkpoint and save it as '<stem>-<sha256[:8]>.pth'.

    The weights are loaded into a freshly built model, so the strict
    ``load_state_dict`` validates every key and shape before the result is
    saved. The output is written to a temporary file first, hashed, and then
    moved into place.
    """
    state_dict = torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=True)
    state_dict = strip_backbone_prefix(state_dict)

    model = build_model(checkpoint, state_dict)
    model.load_state_dict(state_dict)

    tmp_path = f'{checkpoint}.tmp'
    torch.save(model.state_dict(), tmp_path)
    with open(tmp_path, 'rb') as f:
        checksum = hashlib.file_digest(f, 'sha256').hexdigest()
    # os.replace overwrites an existing destination on every platform,
    # whereas os.rename raises on Windows.
    os.replace(tmp_path, f'{checkpoint.removesuffix(".pth")}-{checksum[:8]}.pth')


# Snapshot the directory listing up front, because conversion writes new
# .pth files into the same directory.
for checkpoint in sorted(glob.glob('*.pth')):
    # Converted outputs are named '<stem>-<hash>.pth', so any name containing
    # '-' is skipped (this also skips source files that contain '-').
    if '-' in checkpoint:
        continue

    print(checkpoint)
    convert_checkpoint(checkpoint)
|
|