# CONQUER_RVMR/model/modeling_utils.py
"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
some functions are modified from HuggingFace
(https://github.com/huggingface/transformers)
"""
import torch
from torch import nn
import logging
logger = logging.getLogger(__name__)


def prune_linear_layer(layer, index, dim=0):
    """ Prune a linear layer (a model parameter matrix)
    to keep only the entries in index.
    Return the pruned layer as a new layer with requires_grad=True.
    Used to remove attention heads.
    """
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        if dim == 1:
            b = layer.bias.clone().detach()
        else:
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(
        new_size[1], new_size[0], bias=layer.bias is not None).to(
            layer.weight.device)
    # Copy the selected rows/columns without tracking gradients,
    # then re-enable gradients on the new layer.
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer
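

# A minimal usage sketch for prune_linear_layer. The toy layer size and the
# index tensor below are illustrative assumptions, not part of the original
# module: keep output units 0 and 2 of a 4 -> 3 linear layer.
def _example_prune_linear_layer():
    layer = nn.Linear(4, 3)
    keep = torch.tensor([0, 2], dtype=torch.long)
    pruned = prune_linear_layer(layer, keep, dim=0)  # prune along output dim
    assert pruned.weight.shape == (2, 4)
    assert pruned.bias.shape == (2,)
    return pruned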


def mask_logits(target, mask, eps=-1e4):
    """ Replace logits at masked-out positions (mask == 0) with a large
    negative value so that a following softmax effectively ignores them. """
    return target * mask + (1 - mask) * eps
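

# Illustrative sketch of mask_logits; the score and mask tensors are made-up
# examples. Padded positions (mask == 0) are pushed to about -1e4, so softmax
# assigns them near-zero probability.
def _example_mask_logits():
    scores = torch.tensor([[2.0, 1.0, 0.5]])
    mask = torch.tensor([[1.0, 1.0, 0.0]])  # last position is padding
    probs = torch.softmax(mask_logits(scores, mask), dim=-1)
    return probs  # probability mass on the padded position is ~0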


def load_partial_checkpoint(checkpoint, n_layers, skip_layers=True):
    """ Keep n_layers evenly spaced transformer layers from a 12-layer
    RoBERTa checkpoint and renumber them 0 .. n_layers-1. """
    if skip_layers:
        new_checkpoint = {}
        gap = int(12 / n_layers)
        prefix = "roberta.encoder.layer."
        # Map the kept original layer indices (gap-1, 2*gap-1, ..., 11)
        # to consecutive indices 0 .. n_layers-1.
        layer_range = {str(l): str(i) for i, l in enumerate(
            list(range(gap - 1, 12, gap)))}
        for k, v in checkpoint.items():
            if prefix in k:
                layer_name = k.split(".")
                layer_num = layer_name[3]
                if layer_num in layer_range:
                    layer_name[3] = layer_range[layer_num]
                    new_layer_name = ".".join(layer_name)
                    new_checkpoint[new_layer_name] = v
            else:
                new_checkpoint[k] = v
    else:
        new_checkpoint = checkpoint
    return new_checkpoint
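

# Sketch of how load_partial_checkpoint is meant to be used; the checkpoint
# path and n_layers value are hypothetical. With n_layers=6, original layers
# 1, 3, 5, 7, 9, 11 are kept and renumbered to 0 .. 5.
def _example_load_partial_checkpoint(checkpoint_path="roberta_base.pt"):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    reduced = load_partial_checkpoint(checkpoint, n_layers=6)
    # e.g. "roberta.encoder.layer.11.output.dense.weight" becomes
    #      "roberta.encoder.layer.5.output.dense.weight"
    return reduced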


def load_pretrained_weight(model, state_dict):
    # Load from a PyTorch state_dict.
    # Rename old-style parameter names (gamma/beta) to weight/bias.
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if 'gamma' in key:
            new_key = key.replace('gamma', 'weight')
        if 'beta' in key:
            new_key = key.replace('beta', 'bias')
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=''):
        # Recursively load parameters module by module, collecting
        # missing/unexpected keys and error messages along the way.
        local_metadata = ({} if metadata is None
                          else metadata.get(prefix[:-1], {}))
        module._load_from_state_dict(
            state_dict, prefix, local_metadata, True, missing_keys,
            unexpected_keys, error_msgs)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + '.')

    # If the checkpoint was saved with a `roberta.` prefix but the target
    # model is the bare encoder, account for the prefix while loading.
    start_prefix = ''
    if (not hasattr(model, 'roberta')
            and any(s.startswith('roberta.') for s in state_dict.keys())):
        start_prefix = 'roberta.'
    load(model, prefix=start_prefix)

    if len(missing_keys) > 0:
        logger.info("Weights of {} not initialized from "
                    "pretrained model: {}".format(
                        model.__class__.__name__, missing_keys))
    if len(unexpected_keys) > 0:
        logger.info("Weights from pretrained model not used in "
                    "{}: {}".format(
                        model.__class__.__name__, unexpected_keys))
    if len(error_msgs) > 0:
        raise RuntimeError('Error(s) in loading state_dict for '
                           '{}:\n\t{}'.format(
                               model.__class__.__name__,
                               "\n\t".join(error_msgs)))
    return model
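

# Hedged usage sketch for load_pretrained_weight; the checkpoint file name is
# a placeholder and `model` can be any nn.Module with RoBERTa-style parameter
# names. Missing or unexpected keys are logged rather than raised, so a
# partially matching checkpoint can still be loaded.
def _example_load_pretrained_weight(model, weights_path="pretrained.bin"):
    state_dict = torch.load(weights_path, map_location="cpu")
    return load_pretrained_weight(model, state_dict)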


def pad_tensor_to_mul(tensor, dim=0, mul=8):
    """ Pad tensor along dim to a multiple of mul (8 for tensor cores).
    Returns the padded tensor and the number of padded entries. """
    t_size = list(tensor.size())
    n_pad = mul - t_size[dim] % mul
    if n_pad == mul:
        # Already a multiple of mul; nothing to pad.
        n_pad = 0
        padded_tensor = tensor
    else:
        t_size[dim] = n_pad
        pad = torch.zeros(*t_size, dtype=tensor.dtype, device=tensor.device)
        padded_tensor = torch.cat([tensor, pad], dim=dim)
    return padded_tensor, n_pad
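

# Illustrative use of pad_tensor_to_mul; the tensor shape is a made-up
# example. A batch of 13 rows gets 3 zero rows appended so its length is a
# multiple of 8, which keeps shapes friendly to fp16 tensor-core kernels.
def _example_pad_tensor_to_mul():
    x = torch.randn(13, 768)
    padded, n_pad = pad_tensor_to_mul(x, dim=0, mul=8)
    assert padded.shape == (16, 768) and n_pad == 3
    return padded, n_pad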