TomatoCocotree
上传
6a62ffb
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderModel,
register_model,
register_model_architecture,
)
from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, TransformerEncoder
from fairseq.modules import LayerNorm
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import safe_getattr, safe_hasattr
from .hub_interface import RobertaHubInterface
logger = logging.getLogger(__name__)
@register_model("roberta")
class RobertaModel(FairseqEncoderModel):
@classmethod
def hub_models(cls):
return {
"roberta.base": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz",
"roberta.large": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz",
"roberta.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz",
"roberta.large.wsc": "http://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz",
}
def __init__(self, args, encoder):
super().__init__(encoder)
self.args = args
# We follow BERT's random weight initialization
self.apply(init_bert_params)
self.classification_heads = nn.ModuleDict()
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
"--encoder-layers", type=int, metavar="L", help="num encoder layers"
)
parser.add_argument(
"--encoder-embed-dim",
type=int,
metavar="H",
help="encoder embedding dimension",
)
parser.add_argument(
"--encoder-ffn-embed-dim",
type=int,
metavar="F",
help="encoder embedding dimension for FFN",
)
parser.add_argument(
"--encoder-attention-heads",
type=int,
metavar="A",
help="num encoder attention heads",
)
parser.add_argument(
"--activation-fn",
choices=utils.get_available_activation_fns(),
help="activation function to use",
)
parser.add_argument(
"--pooler-activation-fn",
choices=utils.get_available_activation_fns(),
help="activation function to use for pooler layer",
)
parser.add_argument(
"--encoder-normalize-before",
action="store_true",
help="apply layernorm before each encoder block",
)
parser.add_argument(
"--layernorm-embedding",
action="store_true",
help="add layernorm to embedding",
)
parser.add_argument(
"--dropout", type=float, metavar="D", help="dropout probability"
)
parser.add_argument(
"--attention-dropout",
type=float,
metavar="D",
help="dropout probability for attention weights",
)
parser.add_argument(
"--activation-dropout",
type=float,
metavar="D",
help="dropout probability after activation in FFN",
)
parser.add_argument(
"--pooler-dropout",
type=float,
metavar="D",
help="dropout probability in the masked_lm pooler layers",
)
parser.add_argument(
"--max-positions", type=int, help="number of positional embeddings to learn"
)
parser.add_argument(
"--load-checkpoint-heads",
action="store_true",
help="(re-)register and load heads when loading checkpoints",
)
parser.add_argument(
"--untie-weights-roberta",
action="store_true",
help="Untie weights between embeddings and classifiers in RoBERTa",
)
# args for "Reducing Transformer Depth on Demand with Structured Dropout" (Fan et al., 2019)
parser.add_argument(
"--encoder-layerdrop",
type=float,
metavar="D",
default=0,
help="LayerDrop probability for encoder",
)
parser.add_argument(
"--encoder-layers-to-keep",
default=None,
help="which layers to *keep* when pruning as a comma-separated list",
)
# args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
parser.add_argument(
"--quant-noise-pq",
type=float,
metavar="D",
default=0,
help="iterative PQ quantization noise at training time",
)
parser.add_argument(
"--quant-noise-pq-block-size",
type=int,
metavar="D",
default=8,
help="block size of quantization noise at training time",
)
parser.add_argument(
"--quant-noise-scalar",
type=float,
metavar="D",
default=0,
help="scalar quantization noise and scalar quantization at training time",
)
# args for "Better Fine-Tuning by Reducing Representational Collapse" (Aghajanyan et al. 2020)
parser.add_argument(
"--spectral-norm-classification-head",
action="store_true",
default=False,
help="Apply spectral normalization on the classification head",
)
# args for Fully Sharded Data Parallel (FSDP) training
parser.add_argument(
"--min-params-to-wrap",
type=int,
metavar="D",
default=DEFAULT_MIN_PARAMS_TO_WRAP,
help=(
"minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
),
)
# args for AdaPruning
# In short, it adds regularizarion for the multihead attention module and feed forward neural nets
# For more details, please refer to the paper https://openreview.net/forum?id=_CMSV7FTzGI
parser.add_argument(
"--mha-reg-scale-factor",
type=float,
metavar="D",
default=0.0,
help="scaling factor for regularization term in adptive pruning, recommendation is 0.000375",
)
parser.add_argument(
"--ffn-reg-scale-factor",
type=float,
metavar="D",
default=0.0,
help="scaling factor for regularization term in adptive pruning, recommendation is 0.000375",
)
parser.add_argument(
"--mha-heads-to-keep",
type=int,
metavar="D",
default=-1,
help="number of heads to keep in each multi-head attention module, -1 means keeping all heads",
)
parser.add_argument(
"--ffn-blocks-to-remove",
type=int,
metavar="D",
default=-1,
help="number of feedforward blocks to remove in each transformer layer, -1 means keeping all ffn blocks",
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
from omegaconf import OmegaConf
if OmegaConf.is_config(args):
OmegaConf.set_struct(args, False)
# make sure all arguments are present
base_architecture(args)
if not safe_hasattr(args, "max_positions"):
if not safe_hasattr(args, "tokens_per_sample"):
args.tokens_per_sample = task.max_positions()
args.max_positions = args.tokens_per_sample
encoder = RobertaEncoder(args, task.source_dictionary)
if OmegaConf.is_config(args):
OmegaConf.set_struct(args, True)
return cls(args, encoder)
def forward(
self,
src_tokens,
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
**kwargs,
):
if classification_head_name is not None:
features_only = True
x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)
if classification_head_name is not None:
x = self.classification_heads[classification_head_name](x)
return x, extra
def _get_adaptive_head_loss(self):
norm_loss = 0
scaling = float(self.args.mha_reg_scale_factor)
for layer in self.encoder.sentence_encoder.layers:
norm_loss_layer = 0
for i in range(layer.self_attn.num_heads):
start_idx = i * layer.self_attn.head_dim
end_idx = (i + 1) * layer.self_attn.head_dim
norm_loss_layer += scaling * (
torch.sum(
torch.abs(
layer.self_attn.q_proj.weight[
start_idx:end_idx,
]
)
)
+ torch.sum(
torch.abs(layer.self_attn.q_proj.bias[start_idx:end_idx])
)
)
norm_loss_layer += scaling * (
torch.sum(
torch.abs(
layer.self_attn.k_proj.weight[
start_idx:end_idx,
]
)
)
+ torch.sum(
torch.abs(layer.self_attn.k_proj.bias[start_idx:end_idx])
)
)
norm_loss_layer += scaling * (
torch.sum(
torch.abs(
layer.self_attn.v_proj.weight[
start_idx:end_idx,
]
)
)
+ torch.sum(
torch.abs(layer.self_attn.v_proj.bias[start_idx:end_idx])
)
)
norm_loss += norm_loss_layer
return norm_loss
def _get_adaptive_ffn_loss(self):
ffn_scale_factor = float(self.args.ffn_reg_scale_factor)
filter_loss = 0
for layer in self.encoder.sentence_encoder.layers:
filter_loss += torch.sum(
torch.abs(layer.fc1.weight * ffn_scale_factor)
) + torch.sum(torch.abs(layer.fc2.weight * ffn_scale_factor))
filter_loss += torch.sum(
torch.abs(layer.fc1.bias * ffn_scale_factor)
) + torch.sum(torch.abs(layer.fc2.bias * ffn_scale_factor))
return filter_loss
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
logits = net_output[0].float()
if log_probs:
return F.log_softmax(logits, dim=-1)
else:
return F.softmax(logits, dim=-1)
def register_classification_head(
self, name, num_classes=None, inner_dim=None, **kwargs
):
"""Register a classification head."""
if name in self.classification_heads:
prev_num_classes = self.classification_heads[name].out_proj.out_features
prev_inner_dim = self.classification_heads[name].dense.out_features
if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
logger.warning(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})".format(
name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
)
)
self.classification_heads[name] = RobertaClassificationHead(
input_dim=self.args.encoder_embed_dim,
inner_dim=inner_dim or self.args.encoder_embed_dim,
num_classes=num_classes,
activation_fn=self.args.pooler_activation_fn,
pooler_dropout=self.args.pooler_dropout,
q_noise=self.args.quant_noise_pq,
qn_block_size=self.args.quant_noise_pq_block_size,
do_spectral_norm=self.args.spectral_norm_classification_head,
)
@property
def supported_targets(self):
return {"self"}
@classmethod
def from_pretrained(
cls,
model_name_or_path,
checkpoint_file="model.pt",
data_name_or_path=".",
bpe="gpt2",
**kwargs,
):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
bpe=bpe,
load_checkpoint_heads=True,
**kwargs,
)
logger.info(x["args"])
return RobertaHubInterface(x["args"], x["task"], x["models"][0])
def upgrade_state_dict_named(self, state_dict, name):
prefix = name + "." if name != "" else ""
# rename decoder -> encoder before upgrading children modules
for k in list(state_dict.keys()):
if k.startswith(prefix + "decoder"):
new_k = prefix + "encoder" + k[len(prefix + "decoder") :]
state_dict[new_k] = state_dict[k]
del state_dict[k]
# rename emb_layer_norm -> layernorm_embedding
for k in list(state_dict.keys()):
if ".emb_layer_norm." in k:
new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.")
state_dict[new_k] = state_dict[k]
del state_dict[k]
# upgrade children modules
super().upgrade_state_dict_named(state_dict, name)
# Handle new classification heads present in the state dict.
current_head_names = (
[]
if not hasattr(self, "classification_heads")
else self.classification_heads.keys()
)
keys_to_delete = []
for k in state_dict.keys():
if not k.startswith(prefix + "classification_heads."):
continue
head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
num_classes = state_dict[
prefix + "classification_heads." + head_name + ".out_proj.weight"
].size(0)
inner_dim = state_dict[
prefix + "classification_heads." + head_name + ".dense.weight"
].size(0)
if getattr(self.args, "load_checkpoint_heads", False):
if head_name not in current_head_names:
self.register_classification_head(head_name, num_classes, inner_dim)
else:
if head_name not in current_head_names:
logger.warning(
"deleting classification head ({}) from checkpoint "
"not present in current model: {}".format(head_name, k)
)
keys_to_delete.append(k)
elif (
num_classes
!= self.classification_heads[head_name].out_proj.out_features
or inner_dim
!= self.classification_heads[head_name].dense.out_features
):
logger.warning(
"deleting classification head ({}) from checkpoint "
"with different dimensions than current model: {}".format(
head_name, k
)
)
keys_to_delete.append(k)
for k in keys_to_delete:
del state_dict[k]
# Copy any newly-added classification heads into the state dict
# with their current weights.
if hasattr(self, "classification_heads"):
cur_state = self.classification_heads.state_dict()
for k, v in cur_state.items():
if prefix + "classification_heads." + k not in state_dict:
logger.info("Overwriting " + prefix + "classification_heads." + k)
state_dict[prefix + "classification_heads." + k] = v
# adapt data2vec models
if (
"encoder._ema" in state_dict
and "encoder.lm_head.weight" not in state_dict
):
lm_state = self.encoder.lm_head.state_dict()
for k, v in lm_state.items():
state_dict["encoder.lm_head." + k] = v
for k in list(state_dict.keys()):
if k.startswith("encoder.regression_head") or k == "encoder._ema":
del state_dict[k]
class RobertaLMHead(nn.Module):
"""Head for masked language modeling."""
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
super().__init__()
self.dense = nn.Linear(embed_dim, embed_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.layer_norm = LayerNorm(embed_dim)
if weight is None:
weight = nn.Linear(embed_dim, output_dim, bias=False).weight
self.weight = weight
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, features, masked_tokens=None, **kwargs):
# Only project the masked tokens while training,
# saves both memory and computation
if masked_tokens is not None:
features = features[masked_tokens, :]
x = self.dense(features)
x = self.activation_fn(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = F.linear(x, self.weight) + self.bias
return x
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(
self,
input_dim,
inner_dim,
num_classes,
activation_fn,
pooler_dropout,
q_noise=0,
qn_block_size=8,
do_spectral_norm=False,
):
super().__init__()
self.dense = nn.Linear(input_dim, inner_dim)
self.activation_fn = utils.get_activation_fn(activation_fn)
self.dropout = nn.Dropout(p=pooler_dropout)
self.out_proj = apply_quant_noise_(
nn.Linear(inner_dim, num_classes), q_noise, qn_block_size
)
if do_spectral_norm:
if q_noise != 0:
raise NotImplementedError(
"Attempting to use Spectral Normalization with Quant Noise. This is not officially supported"
)
self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = self.activation_fn(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class RobertaEncoder(FairseqEncoder):
"""RoBERTa encoder."""
def __init__(self, args, dictionary):
super().__init__(dictionary)
# set any missing default values
base_architecture(args)
self.args = args
if args.encoder_layers_to_keep:
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
embed_tokens = self.build_embedding(
len(dictionary), args.encoder_embed_dim, dictionary.pad()
)
self.sentence_encoder = self.build_encoder(args, dictionary, embed_tokens)
self.lm_head = self.build_lm_head(
embed_dim=args.encoder_embed_dim,
output_dim=len(dictionary),
activation_fn=args.activation_fn,
weight=(
self.sentence_encoder.embed_tokens.weight
if not args.untie_weights_roberta
else None
),
)
def build_embedding(self, vocab_size, embedding_dim, padding_idx):
return nn.Embedding(vocab_size, embedding_dim, padding_idx)
def build_encoder(self, args, dictionary, embed_tokens):
encoder = TransformerEncoder(args, dictionary, embed_tokens)
encoder.apply(init_bert_params)
return encoder
def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
return RobertaLMHead(embed_dim, output_dim, activation_fn, weight)
def forward(
self,
src_tokens,
features_only=False,
return_all_hiddens=False,
masked_tokens=None,
**unused,
):
"""
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
features_only (bool, optional): skip LM head and just return
features. If True, the output will be of shape
`(batch, src_len, embed_dim)`.
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
Returns:
tuple:
- the LM output of shape `(batch, src_len, vocab)`
- a dictionary of additional data, where 'inner_states'
is a list of hidden states. Note that the hidden
states have shape `(src_len, batch, vocab)`.
"""
x, extra = self.extract_features(
src_tokens, return_all_hiddens=return_all_hiddens
)
if not features_only:
x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra
def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
encoder_out = self.sentence_encoder(
src_tokens,
return_all_hiddens=return_all_hiddens,
token_embeddings=kwargs.get("token_embeddings", None),
)
# T x B x C -> B x T x C
features = encoder_out["encoder_out"][0].transpose(0, 1)
inner_states = encoder_out["encoder_states"] if return_all_hiddens else None
return features, {"inner_states": inner_states}
def output_layer(self, features, masked_tokens=None, **unused):
return self.lm_head(features, masked_tokens)
def max_positions(self):
"""Maximum output length supported by the encoder."""
return self.args.max_positions
@register_model_architecture("roberta", "roberta")
def base_architecture(args):
args.encoder_layers = safe_getattr(args, "encoder_layers", 12)
args.encoder_embed_dim = safe_getattr(args, "encoder_embed_dim", 768)
args.encoder_ffn_embed_dim = safe_getattr(args, "encoder_ffn_embed_dim", 3072)
args.encoder_attention_heads = safe_getattr(args, "encoder_attention_heads", 12)
args.dropout = safe_getattr(args, "dropout", 0.1)
args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
args.activation_dropout = safe_getattr(args, "activation_dropout", 0.0)
args.pooler_dropout = safe_getattr(args, "pooler_dropout", 0.0)
args.max_source_positions = safe_getattr(args, "max_positions", 512)
args.no_token_positional_embeddings = safe_getattr(
args, "no_token_positional_embeddings", False
)
# BERT has a few structural differences compared to the original Transformer
args.encoder_learned_pos = safe_getattr(args, "encoder_learned_pos", True)
args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", True)
args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", True)
args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
args.encoder_normalize_before = safe_getattr(
args, "encoder_normalize_before", False
)
args.pooler_activation_fn = safe_getattr(args, "pooler_activation_fn", "tanh")
args.untie_weights_roberta = safe_getattr(args, "untie_weights_roberta", False)
# Adaptive input config
args.adaptive_input = safe_getattr(args, "adaptive_input", False)
# LayerDrop config
args.encoder_layerdrop = safe_getattr(args, "encoder_layerdrop", 0.0)
args.encoder_layers_to_keep = safe_getattr(args, "encoder_layers_to_keep", None)
# Quantization noise config
args.quant_noise_pq = safe_getattr(args, "quant_noise_pq", 0)
args.quant_noise_pq_block_size = safe_getattr(args, "quant_noise_pq_block_size", 8)
args.quant_noise_scalar = safe_getattr(args, "quant_noise_scalar", 0)
# R4F config
args.spectral_norm_classification_head = safe_getattr(
args, "spectral_norm_classification_head", False
)
@register_model_architecture("roberta", "roberta_prenorm")
def roberta_prenorm_architecture(args):
args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False)
args.encoder_normalize_before = safe_getattr(args, "encoder_normalize_before", True)
base_architecture(args)
@register_model_architecture("roberta", "roberta_base")
def roberta_base_architecture(args):
base_architecture(args)
@register_model_architecture("roberta", "roberta_large")
def roberta_large_architecture(args):
args.encoder_layers = safe_getattr(args, "encoder_layers", 24)
args.encoder_embed_dim = safe_getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = safe_getattr(args, "encoder_ffn_embed_dim", 4096)
args.encoder_attention_heads = safe_getattr(args, "encoder_attention_heads", 16)
base_architecture(args)
@register_model_architecture("roberta", "xlm")
def xlm_architecture(args):
args.encoder_layers = safe_getattr(args, "encoder_layers", 16)
args.encoder_embed_dim = safe_getattr(args, "encoder_embed_dim", 1280)
args.encoder_ffn_embed_dim = safe_getattr(args, "encoder_ffn_embed_dim", 1280 * 4)
args.encoder_attention_heads = safe_getattr(args, "encoder_attention_heads", 16)
base_architecture(args)