""" Model creation / weight loading / state_dict helpers Hacked together by / Copyright 2020 Ross Wightman """ import logging import os import math from collections import OrderedDict from copy import deepcopy from typing import Callable import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo _logger = logging.getLogger(__name__) def load_state_dict(checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict_key = 'state_dict' if isinstance(checkpoint, dict): if use_ema and 'state_dict_ema' in checkpoint: state_dict_key = 'state_dict_ema' if state_dict_key and state_dict_key in checkpoint: new_state_dict = OrderedDict() for k, v in checkpoint[state_dict_key].items(): # strip `module.` prefix name = k[7:] if k.startswith('module') else k new_state_dict[name] = v state_dict = new_state_dict else: state_dict = checkpoint _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path)) return state_dict else: _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict) def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): resume_epoch = None if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if log_info: _logger.info('Restoring model state from checkpoint...') new_state_dict = OrderedDict() for k, v in checkpoint['state_dict'].items(): name = k[7:] if k.startswith('module') else k new_state_dict[name] = v model.load_state_dict(new_state_dict) if optimizer is not None and 'optimizer' in checkpoint: if log_info: _logger.info('Restoring optimizer state from checkpoint...') optimizer.load_state_dict(checkpoint['optimizer']) if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint: if log_info: _logger.info('Restoring AMP loss scaler state from checkpoint...') loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key]) if 'epoch' in checkpoint: resume_epoch = checkpoint['epoch'] if 'version' in checkpoint and checkpoint['version'] > 1: resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save if log_info: _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) if log_info: _logger.info("Loaded checkpoint '{}'".format(checkpoint_path)) return resume_epoch else: _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True): if cfg is None: cfg = getattr(model, 'default_cfg') if cfg is None or 'url' not in cfg or not cfg['url']: _logger.warning("Pretrained model URL is invalid, using random initialization.") return state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu') if filter_fn is not None: state_dict = filter_fn(state_dict) if in_chans == 1: conv1_name = cfg['first_conv'] _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name) conv1_weight = state_dict[conv1_name + '.weight'] # Some weights are in torch.half, ensure it's float for sum on CPU conv1_type = conv1_weight.dtype conv1_weight = conv1_weight.float() O, I, J, K = conv1_weight.shape if I > 3: assert conv1_weight.shape[1] % 3 == 0 # For models with space2depth stems conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K) conv1_weight = conv1_weight.sum(dim=2, keepdim=False) else: conv1_weight = conv1_weight.sum(dim=1, keepdim=True) conv1_weight = conv1_weight.to(conv1_type) state_dict[conv1_name + '.weight'] = conv1_weight elif in_chans != 3: conv1_name = cfg['first_conv'] conv1_weight = state_dict[conv1_name + '.weight'] conv1_type = conv1_weight.dtype conv1_weight = conv1_weight.float() O, I, J, K = conv1_weight.shape if I != 3: _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name) del state_dict[conv1_name + '.weight'] strict = False else: # NOTE this strategy should be better than random init, but there could be other combinations of # the original RGB input layer weights that'd work better for specific cases. _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name) repeat = int(math.ceil(in_chans / 3)) conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] conv1_weight *= (3 / float(in_chans)) conv1_weight = conv1_weight.to(conv1_type) state_dict[conv1_name + '.weight'] = conv1_weight classifier_name = cfg['classifier'] if num_classes == 1000 and cfg['num_classes'] == 1001: # special case for imagenet trained models with extra background class in pretrained weights classifier_weight = state_dict[classifier_name + '.weight'] state_dict[classifier_name + '.weight'] = classifier_weight[1:] classifier_bias = state_dict[classifier_name + '.bias'] state_dict[classifier_name + '.bias'] = classifier_bias[1:] elif num_classes != cfg['num_classes']: # completely discard fully connected for all other differences between pretrained and created model del state_dict[classifier_name + '.weight'] del state_dict[classifier_name + '.bias'] strict = False model.load_state_dict(state_dict, strict=strict) def extract_layer(model, layer): layer = layer.split('.') module = model if hasattr(model, 'module') and layer[0] != 'module': module = model.module if not hasattr(model, 'module') and layer[0] == 'module': layer = layer[1:] for l in layer: if hasattr(module, l): if not l.isdigit(): module = getattr(module, l) else: module = module[int(l)] else: return module return module def set_layer(model, layer, val): layer = layer.split('.') module = model if hasattr(model, 'module') and layer[0] != 'module': module = model.module lst_index = 0 module2 = module for l in layer: if hasattr(module2, l): if not l.isdigit(): module2 = getattr(module2, l) else: module2 = module2[int(l)] lst_index += 1 lst_index -= 1 for l in layer[:lst_index]: if not l.isdigit(): module = getattr(module, l) else: module = module[int(l)] l = layer[lst_index] setattr(module, l, val) def adapt_model_from_string(parent_module, model_string): separator = '***' state_dict = {} lst_shape = model_string.split(separator) for k in lst_shape: k = k.split(':') key = k[0] shape = k[1][1:-1].split(',') if shape[0] != '': state_dict[key] = [int(i) for i in shape] new_module = deepcopy(parent_module) for n, m in parent_module.named_modules(): old_module = extract_layer(parent_module, n) if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame): if isinstance(old_module, Conv2dSame): conv = Conv2dSame else: conv = nn.Conv2d s = state_dict[n + '.weight'] in_channels = s[1] out_channels = s[0] g = 1 if old_module.groups > 1: in_channels = out_channels g = in_channels new_conv = conv( in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size, bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, groups=g, stride=old_module.stride) set_layer(new_module, n, new_conv) if isinstance(old_module, nn.BatchNorm2d): new_bn = nn.BatchNorm2d( num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, affine=old_module.affine, track_running_stats=True) set_layer(new_module, n, new_bn) if isinstance(old_module, nn.Linear): # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? num_features = state_dict[n + '.weight'][1] new_fc = nn.Linear( in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) set_layer(new_module, n, new_fc) if hasattr(new_module, 'num_features'): new_module.num_features = num_features new_module.eval() parent_module.eval() return new_module def adapt_model_from_file(parent_module, model_variant): adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt') with open(adapt_file, 'r') as f: return adapt_model_from_string(parent_module, f.read().strip()) def build_model_with_cfg( model_cls: Callable, variant: str, pretrained: bool, default_cfg: dict, model_cfg: dict = None, feature_cfg: dict = None, pretrained_strict: bool = True, pretrained_filter_fn: Callable = None, **kwargs): pruned = kwargs.pop('pruned', False) features = False feature_cfg = feature_cfg or {} if kwargs.pop('features_only', False): features = True feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) if 'out_indices' in kwargs: feature_cfg['out_indices'] = kwargs.pop('out_indices') model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) model.default_cfg = deepcopy(default_cfg) if pruned: model = adapt_model_from_file(model, variant) if pretrained: load_pretrained( model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=pretrained_filter_fn, strict=pretrained_strict) if features: feature_cls = FeatureListNet if 'feature_cls' in feature_cfg: feature_cls = feature_cfg.pop('feature_cls') if isinstance(feature_cls, str): feature_cls = feature_cls.lower() if 'hook' in feature_cls: feature_cls = FeatureHookNet else: assert False, f'Unknown feature class {feature_cls}' model = feature_cls(model, **feature_cfg) return model