# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import torch

from fairseq.distributed import fsdp_wrap
from fairseq.models import (
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import TransformerEncoder
from fairseq.models.xmod.hub_interface import XMODHubInterface
from fairseq.models.xmod.transformer_layer_xmod import XMODTransformerEncoderLayerBase
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from ..roberta.model import RobertaEncoder, base_architecture
from ..roberta.model_xlmr import XLMRModel
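
# minimum number of parameters for a layer to be individually wrapped by FSDP
# (mirrors the default used by fairseq's transformer models)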
DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
@register_model("xmod")
class XMODModel(XLMRModel):
@classmethod
def hub_models(cls):
return {
"xmod.base": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.81.1M.tar.gz",
"xmod.large.prenorm": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.large.prenorm.81.500k.tar.gz",
"xmod.base.13.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.13.125k.tar.gz",
"xmod.base.30.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.125k.tar.gz",
"xmod.base.30.195k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.30.195k.tar.gz",
"xmod.base.60.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.125k.tar.gz",
"xmod.base.60.265k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.60.265k.tar.gz",
"xmod.base.75.125k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.125k.tar.gz",
"xmod.base.75.269k": "https://dl.fbaipublicfiles.com/fairseq/models/xmod/xmod.base.75.269k.tar.gz",
}
@classmethod
def from_pretrained(
cls,
model_name_or_path,
checkpoint_file="model.pt",
data_name_or_path=".",
bpe="sentencepiece",
**kwargs,
):
from fairseq import hub_utils
x = hub_utils.from_pretrained(
model_name_or_path,
checkpoint_file,
data_name_or_path,
archive_map=cls.hub_models(),
bpe=bpe,
load_checkpoint_heads=True,
**kwargs,
)
return XMODHubInterface(x["args"], x["task"], x["models"][0])
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
from omegaconf import OmegaConf
if OmegaConf.is_config(args):
OmegaConf.set_struct(args, False)
# make sure all arguments are present
base_architecture(args)
if not hasattr(args, "max_positions"):
if not hasattr(args, "tokens_per_sample"):
args.tokens_per_sample = task.max_positions()
args.max_positions = args.tokens_per_sample
encoder = XMODEncoder(args, task.source_dictionary)
if OmegaConf.is_config(args):
OmegaConf.set_struct(args, True)
return cls(args, encoder)
def forward(
self,
src_tokens,
features_only=False,
return_all_hiddens=False,
classification_head_name=None,
lang_id=None,
**kwargs,
):
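        """
        Forward pass through the X-MOD encoder (and optional classification head).

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip the LM head and return
                encoder features only; forced to True when
                `classification_head_name` is given.
            return_all_hiddens (bool, optional): also return all intermediate
                hidden states (default: False).
            classification_head_name (str, optional): name of a registered
                classification head to apply to the extracted features.
            lang_id (optional): identifier of the language whose adapter
                modules are used inside each encoder layer (assumed to
                correspond to an entry in the architecture's `languages` list).

        Returns a tuple `(x, extra)` from the encoder, with `x` replaced by the
        classification head output when `classification_head_name` is set.
        """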
if classification_head_name is not None:
features_only = True
x, extra = self.encoder(
src_tokens, features_only, return_all_hiddens, lang_id=lang_id, **kwargs
)
if classification_head_name is not None:
x = self.classification_heads[classification_head_name](x)
return x, extra
class XMODEncoder(RobertaEncoder):
"""XMOD encoder."""
def build_encoder(self, args, dictionary, embed_tokens):
encoder = XMODTransformerEncoder(args, dictionary, embed_tokens)
encoder.apply(init_bert_params)
return encoder
def forward(
self,
src_tokens,
features_only=False,
return_all_hiddens=False,
masked_tokens=None,
lang_id=None,
**unused,
):
"""
Args:
src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
features_only (bool, optional): skip LM head and just return
features. If True, the output will be of shape
`(batch, src_len, embed_dim)`.
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
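            masked_tokens (optional): mask passed through to the LM output
                layer when `features_only` is False.
            lang_id (optional): identifier of the language whose adapter
                modules are used in each encoder layer (assumed to match an
                entry in the architecture's `languages` list).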
Returns:
tuple:
- the LM output of shape `(batch, src_len, vocab)`
- a dictionary of additional data, where 'inner_states'
is a list of hidden states. Note that the hidden
                  states have shape `(src_len, batch, embed_dim)`.
"""
x, extra = self.extract_features(
src_tokens, return_all_hiddens=return_all_hiddens, lang_id=lang_id
)
if not features_only:
x = self.output_layer(x, masked_tokens=masked_tokens)
return x, extra
def extract_features(
self, src_tokens, return_all_hiddens=False, lang_id=None, **kwargs
):
encoder_out = self.sentence_encoder(
src_tokens,
return_all_hiddens=return_all_hiddens,
lang_id=lang_id,
token_embeddings=kwargs.get("token_embeddings", None),
)
# T x B x C -> B x T x C
features = encoder_out["encoder_out"][0].transpose(0, 1)
inner_states = encoder_out["encoder_states"] if return_all_hiddens else None
return features, {"inner_states": inner_states}
class XMODTransformerEncoder(TransformerEncoder):
def build_encoder_layer(self, cfg):
layer = XMODTransformerEncoderLayerBase(cfg)
checkpoint = cfg.checkpoint_activations
if checkpoint:
offload_to_cpu = cfg.offload_activations
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer
def forward(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
lang_id=None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed token
                embeddings; if `None` (default), embeddings are recomputed
                from `src_tokens`.
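            lang_id (optional): identifier of the language whose adapter
                modules are used in each encoder layer.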
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
return self.forward_scriptable(
src_tokens,
src_lengths,
return_all_hiddens,
token_embeddings,
lang_id=lang_id,
)
    # TorchScript does not support calling super() methods, so the scriptable
    # subclass cannot reach the base class implementation directly.
    # The current workaround is to add a helper method with a different name
    # and call that helper from the scriptable subclass.
def forward_scriptable(
self,
src_tokens,
src_lengths: Optional[torch.Tensor] = None,
return_all_hiddens: bool = False,
token_embeddings: Optional[torch.Tensor] = None,
lang_id=None,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed token
                embeddings; if `None` (default), embeddings are recomputed
                from `src_tokens`.
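            lang_id (optional): identifier of the language whose adapter
                modules are used in each encoder layer.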
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
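        # on XLA devices always take the padded path: evaluating `.any()` on a
        # lazy tensor would force it to materialize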
has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any()
x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
# account for padding while computing the representation
if has_pads:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
# B x T x C -> T x B x C
x = x.transpose(0, 1)
encoder_states = []
if return_all_hiddens:
encoder_states.append(x)
# encoder layers
for layer in self.layers:
x = layer(
x,
encoder_padding_mask=encoder_padding_mask if has_pads else None,
lang_id=lang_id,
)
if return_all_hiddens:
assert encoder_states is not None
encoder_states.append(x)
if self.layer_norm is not None:
x = self.layer_norm(x)
        # The PyTorch Mobile lite interpreter does not support returning
        # NamedTuple in `forward`, so we use a dictionary instead.
        # TorchScript does not support mixed-type values, so all values are
        # lists. An empty list is equivalent to None.
src_lengths = (
src_tokens.ne(self.padding_idx)
.sum(dim=1, dtype=torch.int32)
.reshape(-1, 1)
.contiguous()
)
return {
"encoder_out": [x], # T x B x C
"encoder_padding_mask": [encoder_padding_mask], # B x T
"encoder_embedding": [encoder_embedding], # B x T x C
"encoder_states": encoder_states, # List[T x B x C]
"src_tokens": [],
"src_lengths": [src_lengths],
}
@register_model_architecture("xmod", "xmod_base_13")
def xmod_base_13_architecture(args):
args.ffn_modules = getattr(args, "ffn_modules", False)
args.adapter_modules = getattr(args, "adapter_modules", True)
args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
args.languages = getattr(
args,
"languages",
[
"ar_AR",
"en_XX",
"fi_FI",
"fr_XX",
"hi_IN",
"id_ID",
"ka_GE",
"ko_KR",
"ru_RU",
"sw_KE",
"ta_IN",
"th_TH",
"vi_VN",
],
)
base_architecture(args)
@register_model_architecture("xmod", "xmod_base_30")
def xmod_base_30_architecture(args):
args.ffn_modules = getattr(args, "ffn_modules", False)
args.adapter_modules = getattr(args, "adapter_modules", True)
args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
args.languages = getattr(
args,
"languages",
[
"ar_AR",
"cs_CZ",
"en_XX",
"eu_ES",
"fi_FI",
"fr_XX",
"hi_IN",
"hr_HR",
"hu_HU",
"hy_AM",
"id_ID",
"it_IT",
"ka_GE",
"ko_KR",
"lt_LT",
"ml_IN",
"mn_MN",
"ms_MY",
"pl_PL",
"ro_RO",
"ru_RU",
"si_LK",
"sk_SK",
"sq_AL",
"sv_SE",
"sw_KE",
"ta_IN",
"th_TH",
"tl_XX",
"vi_VN",
],
)
base_architecture(args)
@register_model_architecture("xmod", "xmod_base_60")
def xmod_base_60_architecture(args):
args.ffn_modules = getattr(args, "ffn_modules", False)
args.adapter_modules = getattr(args, "adapter_modules", True)
args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
args.languages = getattr(
args,
"languages",
[
"af_ZA",
"am_ET",
"ar_AR",
"be_BY",
"bn_IN",
"ca_ES",
"cs_CZ",
"cy_GB",
"da_DK",
"en_XX",
"eo_EO",
"et_EE",
"eu_ES",
"fa_IR",
"fi_FI",
"fr_XX",
"ga_IE",
"gl_ES",
"gu_IN",
"ha_NG",
"hi_IN",
"hr_HR",
"hu_HU",
"hy_AM",
"id_ID",
"is_IS",
"it_IT",
"ka_GE",
"ko_KR",
"ku_TR",
"la_VA",
"lt_LT",
"lv_LV",
"mk_MK",
"ml_IN",
"mn_MN",
"ms_MY",
"ne_NP",
"nl_XX",
"no_XX",
"pl_PL",
"ps_AF",
"pt_XX",
"ro_RO",
"ru_RU",
"sa_IN",
"sd_PK",
"si_LK",
"sk_SK",
"sl_SI",
"so_SO",
"sq_AL",
"sr_RS",
"sv_SE",
"sw_KE",
"ta_IN",
"te_IN",
"th_TH",
"tl_XX",
"vi_VN",
],
)
base_architecture(args)
@register_model_architecture("xmod", "xmod_base_75")
def xmod_base_75_architecture(args):
args.ffn_modules = getattr(args, "ffn_modules", False)
args.adapter_modules = getattr(args, "adapter_modules", True)
args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
args.languages = getattr(
args,
"languages",
[
"af_ZA",
"am_ET",
"ar_AR",
"as_IN",
"be_BY",
"bn_IN",
"br_FR",
"bs_BA",
"ca_ES",
"cs_CZ",
"cy_GB",
"da_DK",
"en_XX",
"eo_EO",
"et_EE",
"eu_ES",
"fa_IR",
"fi_FI",
"fr_XX",
"fy_NL",
"ga_IE",
"gd_GB",
"gl_ES",
"gu_IN",
"ha_NG",
"hi_IN",
"hr_HR",
"hu_HU",
"hy_AM",
"id_ID",
"is_IS",
"it_IT",
"jv_ID",
"ka_GE",
"kn_IN",
"ko_KR",
"ku_TR",
"la_VA",
"lt_LT",
"lv_LV",
"mg_MG",
"mk_MK",
"ml_IN",
"mn_MN",
"mr_IN",
"ms_MY",
"ne_NP",
"nl_XX",
"no_XX",
"om_KE",
"or_IN",
"pa_IN",
"pl_PL",
"ps_AF",
"pt_XX",
"ro_RO",
"ru_RU",
"sa_IN",
"sd_PK",
"si_LK",
"sk_SK",
"sl_SI",
"so_SO",
"sq_AL",
"sr_RS",
"su_ID",
"sv_SE",
"sw_KE",
"ta_IN",
"te_IN",
"th_TH",
"tl_XX",
"vi_VN",
"xh_ZA",
"yi_DE",
],
)
base_architecture(args)
@register_model_architecture("xmod", "xmod_base")
def xmod_base_architecture(args):
args.ffn_modules = getattr(args, "ffn_modules", False)
args.adapter_modules = getattr(args, "adapter_modules", True)
args.adapter_layer_norm = getattr(args, "adapter_layer_norm", False)
args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", True)
args.ln_before_adapter = getattr(args, "ln_before_adapter", True)
args.languages = getattr(
args,
"languages",
[
"en_XX",
"id_ID",
"vi_VN",
"ru_RU",
"fa_IR",
"sv_SE",
"ja_XX",
"fr_XX",
"de_DE",
"ro_RO",
"ko_KR",
"hu_HU",
"es_XX",
"fi_FI",
"uk_UA",
"da_DK",
"pt_XX",
"no_XX",
"th_TH",
"pl_PL",
"bg_BG",
"nl_XX",
"zh_CN",
"he_IL",
"el_GR",
"it_IT",
"sk_SK",
"hr_HR",
"tr_TR",
"ar_AR",
"cs_CZ",
"lt_LT",
"hi_IN",
"zh_TW",
"ca_ES",
"ms_MY",
"sl_SI",
"lv_LV",
"ta_IN",
"bn_IN",
"et_EE",
"az_AZ",
"sq_AL",
"sr_RS",
"kk_KZ",
"ka_GE",
"tl_XX",
"ur_PK",
"is_IS",
"hy_AM",
"ml_IN",
"mk_MK",
"be_BY",
"la_VA",
"te_IN",
"eu_ES",
"gl_ES",
"mn_MN",
"kn_IN",
"ne_NP",
"sw_KE",
"si_LK",
"mr_IN",
"af_ZA",
"gu_IN",
"cy_GB",
"eo_EO",
"km_KH",
"ky_KG",
"uz_UZ",
"ps_AF",
"pa_IN",
"ga_IE",
"ha_NG",
"am_ET",
"lo_LA",
"ku_TR",
"so_SO",
"my_MM",
"or_IN",
"sa_IN",
],
)
base_architecture(args)
@register_model_architecture("xmod", "xmod_large_prenorm")
def xmod_large_prenorm_architecture(args):
args.ffn_modules = getattr(args, "ffn_modules", False)
args.adapter_modules = getattr(args, "adapter_modules", True)
args.adapter_layer_norm = getattr(args, "adapter_layer_norm", True)
args.adapter_reuse_layer_norm = getattr(args, "adapter_reuse_layer_norm", False)
args.ln_before_adapter = getattr(args, "ln_before_adapter", False)
args.bottleneck = getattr(args, "bottleneck", 4)
args.languages = getattr(
args,
"languages",
[
"en_XX",
"id_ID",
"vi_VN",
"ru_RU",
"fa_IR",
"sv_SE",
"ja_XX",
"fr_XX",
"de_DE",
"ro_RO",
"ko_KR",
"hu_HU",
"es_XX",
"fi_FI",
"uk_UA",
"da_DK",
"pt_XX",
"no_XX",
"th_TH",
"pl_PL",
"bg_BG",
"nl_XX",
"zh_CN",
"he_IL",
"el_GR",
"it_IT",
"sk_SK",
"hr_HR",
"tr_TR",
"ar_AR",
"cs_CZ",
"lt_LT",
"hi_IN",
"zh_TW",
"ca_ES",
"ms_MY",
"sl_SI",
"lv_LV",
"ta_IN",
"bn_IN",
"et_EE",
"az_AZ",
"sq_AL",
"sr_RS",
"kk_KZ",
"ka_GE",
"tl_XX",
"ur_PK",
"is_IS",
"hy_AM",
"ml_IN",
"mk_MK",
"be_BY",
"la_VA",
"te_IN",
"eu_ES",
"gl_ES",
"mn_MN",
"kn_IN",
"ne_NP",
"sw_KE",
"si_LK",
"mr_IN",
"af_ZA",
"gu_IN",
"cy_GB",
"eo_EO",
"km_KH",
"ky_KG",
"uz_UZ",
"ps_AF",
"pa_IN",
"ga_IE",
"ha_NG",
"am_ET",
"lo_LA",
"ku_TR",
"so_SO",
"my_MM",
"or_IN",
"sa_IN",
],
)
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True)
args.encoder_layers = getattr(args, "encoder_layers", 24)
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
base_architecture(args)
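

# ---------------------------------------------------------------------------
# Usage sketch (comments only; nothing below runs on import).
# It assumes the RoBERTa-style hub API that XMODHubInterface is expected to
# inherit (`encode`, `eval`, `extract_features`) and assumes `lang_id` accepts
# a language code from the architecture's `languages` list; verify both
# assumptions against your fairseq version.
#
#   from fairseq.models.xmod import XMODModel
#
#   xmod = XMODModel.from_pretrained("xmod.base")  # resolved via hub_models()
#   xmod.eval()
#   tokens = xmod.encode("Hello world!")           # LongTensor of token ids
#   feats = xmod.extract_features(tokens, lang_id="en_XX")
# ---------------------------------------------------------------------------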