#!/usr/bin/env python3

"""Convert Satlas-Pretrain model checkpoints to a format accepted by TorchGeo.

Every ``*.pth`` file in the current working directory that has not already
been converted (i.e. whose name contains no ``-``) is loaded, stripped down to
its backbone weights, loaded into a matching timm/torchvision model, and
re-saved as ``<stem>-<first 8 hex digits of sha256>.pth``.

Reference implementation:

* https://github.com/allenai/satlaspretrain_models/blob/main/satlaspretrain_models/models/backbones.py
"""

import glob
import hashlib
import os

import timm
import torch
import torchvision

# (probe key, prefix) pairs used to locate the backbone weights inside a
# Satlas checkpoint. They are tried in order and the first probe key present
# in the state dict selects the prefix to keep and strip.
_BACKBONE_PREFIXES = [
    ('backbone.backbone.resnet.conv1.weight', 'backbone.backbone.resnet.'),
    ('backbone.resnet.conv1.weight', 'backbone.resnet.'),
    (
        'backbone.backbone.backbone.features.0.0.weight',
        'backbone.backbone.backbone.',
    ),
    ('backbone.backbone.features.0.0.weight', 'backbone.backbone.'),
]


def _extract_backbone(state_dict):
    """Return only the backbone entries of *state_dict*, with their prefix removed.

    The first entry of ``_BACKBONE_PREFIXES`` whose probe key is present in
    *state_dict* selects a prefix. Only keys starting with that prefix are
    kept, and the prefix is stripped from the front of each key (and only the
    front, unlike ``str.replace``). If no probe key matches, *state_dict* is
    returned unchanged.
    """
    for probe, prefix in _BACKBONE_PREFIXES:
        if probe in state_dict:
            return {
                key.removeprefix(prefix): value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
    return state_dict


def _build_model(checkpoint, state_dict):
    """Create an untrained model whose architecture matches *checkpoint*.

    * ResNet (``resnet`` in the file name): built with ``timm``. The model
      name is the second ``_``-separated field of the file name, and the
      number of input channels is read from ``conv1.weight``.
    * Swin V2 (``swint`` -> tiny, ``swinb`` -> base): built with
      ``torchvision``. The first patch-embedding convolution is replaced so
      its channel counts and kernel size match ``features.0.0.weight``.

    Raises:
        ValueError: if the file name does not identify a supported
            architecture. Without this check, a stale model from a previous
            loop iteration could silently be reused.
    """
    if 'resnet' in checkpoint:
        # Conv weight layout is (out_channels, in_channels, kh, kw)
        in_chans = state_dict['conv1.weight'].shape[1]
        model_name = checkpoint.split('_')[1]
        return timm.create_model(model_name, in_chans=in_chans)

    if 'swin' in checkpoint:
        out_channels, num_channels, kernel_size_0, kernel_size_1 = state_dict[
            'features.0.0.weight'
        ].shape

        if 'swint' in checkpoint:
            model = torchvision.models.swin_v2_t()
        elif 'swinb' in checkpoint:
            model = torchvision.models.swin_v2_b()
        else:
            raise ValueError(f'Unknown Swin variant in {checkpoint!r}')

        # Swap the patch-embedding conv so it accepts the checkpoint's number
        # of input channels; stride 4 matches torchvision's default Swin V2
        # patch size.
        model.features[0][0] = torch.nn.Conv2d(
            num_channels,
            out_channels,
            kernel_size=(kernel_size_0, kernel_size_1),
            stride=(4, 4),
        )
        return model

    raise ValueError(f'Cannot infer model architecture from {checkpoint!r}')


def _save_with_checksum(model, checkpoint):
    """Save *model*'s weights under a checksum-suffixed name next to *checkpoint*.

    The weights are written to ``<checkpoint>.tmp`` first. The SHA-256 of that
    file is then computed, and the file is renamed to
    ``<checkpoint without .pth>-<first 8 hex digits>.pth``. The ``-`` in the
    new name is what ``main`` uses to skip already-converted files.
    """
    tmp_path = f'{checkpoint}.tmp'
    torch.save(model.state_dict(), tmp_path)

    with open(tmp_path, 'rb') as f:
        checksum = hashlib.file_digest(f, 'sha256').hexdigest()

    os.rename(tmp_path, f'{checkpoint[:-4]}-{checksum[:8]}.pth')


def main():
    """Convert every not-yet-converted ``*.pth`` checkpoint in the working directory."""
    for checkpoint in glob.iglob('*.pth'):
        # Converted outputs are named '<stem>-<checksum>.pth'; skip them
        if '-' in checkpoint:
            continue

        print(checkpoint)

        # Map to CPU so conversion does not require the device the
        # checkpoint was saved from
        state_dict = torch.load(
            checkpoint, map_location=torch.device('cpu'), weights_only=True
        )

        state_dict = _extract_backbone(state_dict)
        model = _build_model(checkpoint, state_dict)

        # strict=True by default: fails loudly on any missing/unexpected key
        model.load_state_dict(state_dict)

        _save_with_checksum(model, checkpoint)


if __name__ == '__main__':
    main()