# File: satlas/convert.py
# Author: ajstewart
# Commit: 081d660 "Enable for all files" (unverified)
#!/usr/bin/env python3
"""Convert Satlas-Pretrain model checkpoints to a format accepted by TorchGeo.
Reference implementation:
* https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py
"""
import glob
import hashlib
import os
import timm
import torch
import torchvision
# Each entry pairs a key prefix with a sentinel key suffix. A Satlas checkpoint
# whose state dict contains ``prefix + sentinel`` stores its backbone weights
# under ``prefix``. Entries are tried in order and the first match wins.
BACKBONE_PREFIXES = (
    ('backbone.backbone.resnet.', 'conv1.weight'),
    ('backbone.resnet.', 'conv1.weight'),
    ('backbone.backbone.backbone.', 'features.0.0.weight'),
    ('backbone.backbone.', 'features.0.0.weight'),
)


def extract_backbone(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return only the backbone weights of a Satlas checkpoint, prefix stripped.

    Args:
        state_dict: Full checkpoint state dict as returned by ``torch.load``.

    Returns:
        A new dict containing only the keys under the first prefix in
        ``BACKBONE_PREFIXES`` whose sentinel key is present, with that prefix
        removed from the front of each key. If no sentinel key is present,
        ``state_dict`` itself is returned unchanged.
    """
    for prefix, sentinel in BACKBONE_PREFIXES:
        if prefix + sentinel in state_dict:
            # removeprefix strips only the leading occurrence; str.replace would
            # also delete any later repetition of the prefix inside a key.
            return {
                key.removeprefix(prefix): value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
    return state_dict


def build_model(checkpoint: str, state_dict: dict[str, torch.Tensor]) -> torch.nn.Module:
    """Instantiate an untrained model whose architecture matches ``checkpoint``.

    The architecture is chosen from the checkpoint file name:

    * names containing ``resnet`` build a timm model named by the second
      ``_``-separated field of the file name;
    * names containing ``swint`` / ``swinb`` build torchvision's
      ``swin_v2_t`` / ``swin_v2_b``.

    The number of input channels is read from the first convolution weight in
    ``state_dict`` so that checkpoints with non-RGB inputs can be loaded.

    Args:
        checkpoint: Checkpoint file name.
        state_dict: Backbone-only state dict (see ``extract_backbone``).

    Returns:
        A freshly constructed model ready for ``load_state_dict``.

    Raises:
        ValueError: If the file name matches no supported architecture.
        KeyError: If the expected first-convolution weight is missing.
    """
    if 'resnet' in checkpoint:
        # Conv2d weights are laid out as (out_channels, in_channels, kH, kW).
        in_chans = state_dict['conv1.weight'].shape[1]
        model_name = checkpoint.split('_')[1]
        return timm.create_model(model_name, in_chans=in_chans)

    if 'swin' in checkpoint:
        if 'swint' in checkpoint:
            model = torchvision.models.swin_v2_t()
        elif 'swinb' in checkpoint:
            model = torchvision.models.swin_v2_b()
        else:
            raise ValueError(f'Unsupported Swin variant in checkpoint name: {checkpoint}')
        out_channels, num_channels, kernel_h, kernel_w = state_dict['features.0.0.weight'].shape
        # Swap the patch-embedding conv so its input channel count matches the
        # checkpoint. Stride (4, 4) presumably mirrors torchvision's 4x4 Swin V2
        # patch size -- confirm if new variants are added.
        model.features[0][0] = torch.nn.Conv2d(
            num_channels,
            out_channels,
            kernel_size=(kernel_h, kernel_w),
            stride=(4, 4),
        )
        return model

    raise ValueError(f'Unrecognized architecture in checkpoint name: {checkpoint}')


def convert_checkpoint(checkpoint: str) -> str:
    """Convert one Satlas checkpoint and write it as ``<stem>-<sha256[:8]>.pth``.

    Args:
        checkpoint: Path of the original Satlas ``.pth`` file.

    Returns:
        Path of the converted file.
    """
    # Map to CPU so conversion works on machines without the training device.
    state_dict = torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=True)
    state_dict = extract_backbone(state_dict)

    model = build_model(checkpoint, state_dict)
    # Round-trip through the model so the saved keys match what it expects.
    model.load_state_dict(state_dict)

    # Save to a temporary name first: the final name depends on the file's hash.
    tmp_path = f'{checkpoint}.tmp'
    torch.save(model.state_dict(), tmp_path)
    with open(tmp_path, 'rb') as f:
        checksum = hashlib.file_digest(f, 'sha256').hexdigest()

    stem, _ = os.path.splitext(checkpoint)
    destination = f'{stem}-{checksum[:8]}.pth'
    # os.replace overwrites an existing destination on every platform.
    os.replace(tmp_path, destination)
    return destination


def main() -> None:
    """Convert every unconverted ``*.pth`` file in the current directory."""
    # Snapshot (and sort) the listing up front so newly written outputs are
    # never picked up and the processing order is deterministic.
    for checkpoint in sorted(glob.glob('*.pth')):
        # Converted outputs carry a "-<hash>" suffix; skip them.
        if '-' in checkpoint:
            continue
        print(checkpoint)
        convert_checkpoint(checkpoint)


if __name__ == '__main__':
    main()