from dataclasses import dataclass, field
from typing import Optional

from omegaconf import II

from fairseq import options, utils
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import (
    FairseqLanguageModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    DEFAULT_MIN_PARAMS_TO_WRAP,
    Embedding,
    TransformerDecoder,
)
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
from fairseq.utils import safe_getattr, safe_hasattr

DEFAULT_MAX_TARGET_POSITIONS = 1024


@dataclass
class TransformerLanguageModelConfig(FairseqDataclass):
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="relu", metadata={"help": "activation function to use"}
    )
    dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
    attention_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability for attention weights"}
    )
    activation_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    relu_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN."}
    )
    decoder_embed_dim: int = field(
        default=512, metadata={"help": "decoder embedding dimension"}
    )
    decoder_output_dim: int = field(
        default=512, metadata={"help": "decoder output dimension"}
    )
    decoder_input_dim: int = field(
        default=512, metadata={"help": "decoder input dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=2048, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num decoder layers"})
    decoder_attention_heads: int = field(
        default=8, metadata={"help": "num decoder attention heads"}
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    no_decoder_final_norm: bool = field(
        default=False,
        metadata={"help": "don't add an extra layernorm after the last decoder block"},
    )
    adaptive_softmax_cutoff: Optional[str] = field(
        default=None,
        metadata={
            "help": "comma separated list of adaptive softmax cutoff points. "
            "Must be used with adaptive_loss criterion"
        },
    )
    adaptive_softmax_dropout: float = field(
        default=0,
        metadata={"help": "sets adaptive softmax dropout for the tail projections"},
    )
    adaptive_softmax_factor: float = field(
        default=4, metadata={"help": "adaptive softmax factor"}
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    character_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, uses character embedding convolutions to produce token embeddings"
        },
    )
    character_filters: str = field(
        default="[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
        metadata={
            "help": "list of (kernel width, number of filters) pairs for the "
            "character CNN"
        },
    )
    character_embedding_dim: int = field(
        default=4, metadata={"help": "size of character embeddings"}
    )
    char_embedder_highway_layers: int = field(
        default=2,
        metadata={"help": "number of highway layers for character token embedder"},
    )
    adaptive_input: bool = field(
        default=False, metadata={"help": "if set, uses adaptive input"}
    )
    adaptive_input_factor: float = field(
        default=4, metadata={"help": "adaptive input factor"}
    )
    adaptive_input_cutoff: Optional[str] = field(
        default=None,
        metadata={"help": "comma separated list of adaptive input cutoff points."},
    )
    tie_adaptive_weights: bool = field(
        default=False,
        metadata={
            "help": "if set, ties the weights of adaptive softmax and adaptive input"
        },
    )
    tie_adaptive_proj: bool = field(
        default=False,
        metadata={
            "help": "if set, ties the projection weights of adaptive softmax and adaptive input"
        },
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    layernorm_embedding: bool = field(
        default=False, metadata={"help": "add layernorm to embedding"}
    )
    no_scale_embedding: bool = field(
        default=False, metadata={"help": "if True, don't scale embeddings"}
    )
    checkpoint_activations: bool = field(
        default=False, metadata={"help": "checkpoint activations at each layer"}
    )
    offload_activations: bool = field(
        default=False,
        metadata={"help": "move checkpointed activations to CPU after they are used."},
    )
    # config for "Reducing Transformer Depth on Demand with Structured Dropout"
    # (Fan et al., 2019)
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "LayerDrop probability for decoder"}
    )
    decoder_layers_to_keep: Optional[str] = field(
        default=None,
        metadata={
            "help": "which layers to *keep* when pruning as a comma-separated list"
        },
    )
    # config for "Training with Quantization Noise for Extreme Model Compression"
    # (Fan*, Stock* et al., 2020)
    quant_noise_pq: float = field(
        default=0.0,
        metadata={"help": "iterative PQ quantization noise at training time"},
    )
    quant_noise_pq_block_size: int = field(
        default=8,
        metadata={"help": "block size of quantization noise at training time"},
    )
    quant_noise_scalar: float = field(
        default=0.0,
        metadata={
            "help": "scalar quantization noise and scalar quantization at training time"
        },
    )
    # config for Fully Sharded Data Parallel (FSDP) training
    min_params_to_wrap: int = field(
        default=DEFAULT_MIN_PARAMS_TO_WRAP,
        metadata={
            "help": (
                "minimum number of params for a layer to be wrapped with FSDP() when "
                "training with --ddp-backend=fully_sharded. Smaller values will "
                "improve memory efficiency, but may make torch.distributed "
                "communication less efficient due to smaller input sizes. This option "
                "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
                "--offload-activations are passed."
            )
        },
    )
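
    # A hedged CLI sketch of how the option above is typically combined with
    # fully sharded training (dataset path and threshold are placeholders):
    #
    #   fairseq-train data-bin/wikitext-103 \
    #       --task language_modeling --arch transformer_lm_gpt \
    #       --ddp-backend fully_sharded --min-params-to-wrap 100000000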
    # config for "BASE Layers: Simplifying Training of Large, Sparse Models"
    # (Lewis et al., 2021)
    base_layers: Optional[int] = field(
        default=0, metadata={"help": "number of BASE layers in total"}
    )
    base_sublayers: Optional[int] = field(
        default=1, metadata={"help": "number of sublayers in each BASE layer"}
    )
    base_shuffle: Optional[int] = field(
        default=1,
        metadata={"help": "shuffle tokens between workers before computing assignment"},
    )
    # config for "NormFormer: Improved Transformer Pretraining with Extra
    # Normalization" (Shleifer et al., 2021)
    scale_fc: Optional[bool] = field(
        default=False,
        metadata={"help": "Insert LayerNorm between fully connected layers"},
    )
    scale_attn: Optional[bool] = field(
        default=False, metadata={"help": "Insert LayerNorm after attention"}
    )
    scale_heads: Optional[bool] = field(
        default=False,
        metadata={"help": "Learn a scale coefficient for each attention head"},
    )
    scale_resids: Optional[bool] = field(
        default=False,
        metadata={"help": "Learn a scale coefficient for each residual connection"},
    )
    # xFormers arguments
    decoder_xformers_att_config: Optional[str] = field(
        default=None,
        metadata={
            "help": "config for xFormers library attention, defined in xformers.components.attention.AttentionConfig",
        },
    )
    # options from other parts of the config
    add_bos_token: bool = II("task.add_bos_token")
    tokens_per_sample: int = II("task.tokens_per_sample")
    max_target_positions: Optional[int] = II("task.max_target_positions")
    tpu: bool = II("common.tpu")
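

# How the config above surfaces at the command line (a hedged sketch, not an
# exhaustive description of the fairseq plumbing): each dataclass field is
# exposed as a --kebab-case flag with the given default and help string, and
# the II("...") fields are interpolated from other parts of the config at
# runtime (e.g. tokens_per_sample resolves to cfg.task.tokens_per_sample):
#
#   fairseq-train data-bin/wikitext-103 --task language_modeling \
#       --arch transformer_lm --decoder-embed-dim 512 --activation-fn gelu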


@register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
class TransformerLanguageModel(FairseqLanguageModel):
    @classmethod
    def hub_models(cls):
        def moses_fastbpe(path):
            return {"path": path, "tokenizer": "moses", "bpe": "fastbpe"}

        def spm(path):
            return {"path": path, "tokenizer": "space", "bpe": "sentencepiece"}

        return {
            "transformer_lm.gbw.adaptive_huge": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2",
            "transformer_lm.wiki103.adaptive": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2",
            "transformer_lm.wmt19.en": moses_fastbpe(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.bz2"
            ),
            "transformer_lm.wmt19.de": moses_fastbpe(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.bz2"
            ),
            "transformer_lm.wmt19.ru": moses_fastbpe(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2"
            ),
            "transformer_lm.wmt20.en": spm(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.en.tar.gz"
            ),
            "transformer_lm.wmt20.ta": spm(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.ta.tar.gz"
            ),
            "transformer_lm.wmt20.iu.news": spm(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.news.tar.gz"
            ),
            "transformer_lm.wmt20.iu.nh": spm(
                "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.nh.tar.gz"
            ),
        }
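
    # Example of loading one of the checkpoints above through torch.hub, as a
    # sketch following the fairseq README (requires the hub dependencies, e.g.
    # the moses tokenizer and fastBPE, to be installed):
    #
    #   import torch
    #   en_lm = torch.hub.load(
    #       "pytorch/fairseq",
    #       "transformer_lm.wmt19.en",
    #       tokenizer="moses",
    #       bpe="fastbpe",
    #   )
    #   en_lm.eval()
    #   en_lm.sample("Barack Obama", beam=1, sampling=True, sampling_topk=10)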

    def __init__(self, decoder):
        super().__init__(decoder)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if safe_getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = safe_getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        if args.character_embeddings:
            # character_filters is stored as the string form of a list of
            # (kernel width, num filters) pairs; parse it back into a list
            embed_tokens = CharacterTokenEmbedder(
                task.source_dictionary,
                eval(args.character_filters),
                args.character_embedding_dim,
                args.decoder_embed_dim,
                args.char_embedder_highway_layers,
            )
        elif args.adaptive_input:
            embed_tokens = AdaptiveInput(
                len(task.source_dictionary),
                task.source_dictionary.pad(),
                args.decoder_input_dim,
                args.adaptive_input_factor,
                args.decoder_embed_dim,
                options.eval_str_list(args.adaptive_input_cutoff, type=int),
                args.quant_noise_pq,
                args.quant_noise_pq_block_size,
            )
        else:
            embed_tokens = cls.build_embedding(
                args, task.source_dictionary, args.decoder_input_dim
            )

        if args.tie_adaptive_weights:
            assert args.adaptive_input
            assert args.adaptive_input_factor == args.adaptive_softmax_factor
            assert (
                args.adaptive_softmax_cutoff == args.adaptive_input_cutoff
            ), "{} != {}".format(
                args.adaptive_softmax_cutoff, args.adaptive_input_cutoff
            )
            assert args.decoder_input_dim == args.decoder_output_dim

        decoder = TransformerDecoder(
            args, task.target_dictionary, embed_tokens, no_encoder_attn=True
        )
        return cls(decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
        return embed_tokens
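

# A custom checkpoint can also be loaded directly through the model class; a
# hedged sketch using BaseFairseqModel.from_pretrained (directory and file
# names below are placeholders):
#
#   custom_lm = TransformerLanguageModel.from_pretrained(
#       "/path/to/model_dir",
#       checkpoint_file="checkpoint_best.pt",
#       data_name_or_path="/path/to/data-bin",
#   )
#   custom_lm.sample("Barack Obama", beam=5)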


def base_lm_architecture(args):
    # backward compatibility for older model checkpoints
    if safe_hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence
        # of that option to detect an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if safe_hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.0)

    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 8)
    args.adaptive_softmax_cutoff = safe_getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = safe_getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = safe_getattr(args, "adaptive_softmax_factor", 4)
    args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", False)
    args.activation_fn = safe_getattr(args, "activation_fn", "relu")

    args.decoder_layerdrop = safe_getattr(args, "decoder_layerdrop", 0)
    args.decoder_layers_to_keep = safe_getattr(args, "decoder_layers_to_keep", None)
    args.quant_noise_pq = safe_getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = safe_getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = safe_getattr(args, "quant_noise_scalar", 0)

    args.base_layers = safe_getattr(args, "base_layers", 0)
    args.base_sublayers = safe_getattr(args, "base_sublayers", 1)
    args.base_shuffle = safe_getattr(args, "base_shuffle", False)

    args.add_bos_token = safe_getattr(args, "add_bos_token", False)
    args.no_token_positional_embeddings = safe_getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = safe_getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = safe_getattr(args, "character_embeddings", False)

    args.decoder_output_dim = safe_getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = safe_getattr(
        args, "decoder_input_dim", args.decoder_embed_dim
    )

    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = safe_getattr(args, "no_decoder_final_norm", False)

    args.adaptive_input = safe_getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = safe_getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = safe_getattr(args, "adaptive_input_cutoff", None)

    args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = safe_getattr(args, "tie_adaptive_proj", False)

    args.no_scale_embedding = safe_getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = safe_getattr(args, "layernorm_embedding", False)
    args.checkpoint_activations = safe_getattr(args, "checkpoint_activations", False)
    args.offload_activations = safe_getattr(args, "offload_activations", False)
    args.scale_fc = safe_getattr(args, "scale_fc", False)
    args.scale_attn = safe_getattr(args, "scale_attn", False)
    args.scale_heads = safe_getattr(args, "scale_heads", False)
    args.scale_resids = safe_getattr(args, "scale_resids", False)
    # offloading activations to CPU requires them to be checkpointed
    if args.offload_activations:
        args.checkpoint_activations = True
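

# The architecture presets below only fill in attributes that are still unset
# (safe_getattr keeps any value that already exists), so command-line overrides
# and more specific presets always win over these defaults. A minimal sketch
# with a hypothetical namespace, not part of this module:
#
#   from argparse import Namespace
#   ns = Namespace(decoder_layers=4)
#   transformer_lm_big(ns)
#   assert ns.decoder_layers == 4        # explicit value is kept
#   assert ns.decoder_embed_dim == 1024  # unset value takes the preset default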


@register_model_architecture("transformer_lm", "transformer_lm_big")
def transformer_lm_big(args):
    args.decoder_layers = safe_getattr(args, "decoder_layers", 12)
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 16)
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_wiki103")
@register_model_architecture("transformer_lm", "transformer_lm_baevski_wiki103")
def transformer_lm_baevski_wiki103(args):
    args.decoder_layers = safe_getattr(args, "decoder_layers", 16)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 8)
    args.dropout = safe_getattr(args, "dropout", 0.3)
    args.adaptive_input = safe_getattr(args, "adaptive_input", True)
    args.tie_adaptive_weights = safe_getattr(args, "tie_adaptive_weights", True)
    args.adaptive_input_cutoff = safe_getattr(
        args, "adaptive_input_cutoff", "20000,60000"
    )
    args.adaptive_softmax_cutoff = safe_getattr(
        args, "adaptive_softmax_cutoff", "20000,60000"
    )
    args.adaptive_softmax_dropout = safe_getattr(args, "adaptive_softmax_dropout", 0.2)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = safe_getattr(args, "activation_dropout", 0.1)
    args.no_decoder_final_norm = safe_getattr(args, "no_decoder_final_norm", True)
    args.tie_adaptive_proj = safe_getattr(args, "tie_adaptive_proj", True)
    transformer_lm_big(args)


@register_model_architecture("transformer_lm", "transformer_lm_gbw")
@register_model_architecture("transformer_lm", "transformer_lm_baevski_gbw")
def transformer_lm_baevski_gbw(args):
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 512)
    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
    args.no_decoder_final_norm = safe_getattr(args, "no_decoder_final_norm", True)
    transformer_lm_big(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt")
def transformer_lm_gpt(args):
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 768)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 3072)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 12)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 12)
    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
    args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_small")
def transformer_lm_gpt2_small(args):
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1024)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 4096)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 24)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 16)
    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
    args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_tiny")
def transformer_lm_gpt2_tiny(args):
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 64)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 64)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 2)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 1)
    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
    args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_medium")
def transformer_lm_gpt2_medium(args):
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1280)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 5120)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 36)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 20)
    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
    args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_big")
def transformer_lm_gpt2_big(args):
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1600)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 6400)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 48)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 25)
    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
    args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_big_wide")
def transformer_lm_gpt2_big_wide(args):
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 2048)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 8192)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 24)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 32)
    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
    args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt2_bigger")
def transformer_lm_gpt2_bigger(args):
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 2048)
    args.decoder_ffn_embed_dim = safe_getattr(args, "decoder_ffn_embed_dim", 8192)
    args.decoder_layers = safe_getattr(args, "decoder_layers", 48)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 32)
    args.dropout = safe_getattr(args, "dropout", 0.1)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.1)
    args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


def base_gpt3_architecture(args):
    args.decoder_input_dim = args.decoder_embed_dim
    args.decoder_output_dim = args.decoder_embed_dim
    args.decoder_ffn_embed_dim = safe_getattr(
        args, "decoder_ffn_embed_dim", args.decoder_embed_dim * 4
    )
    # GPT-3 used learned positional embeddings, rather than sinusoidal
    args.decoder_learned_pos = safe_getattr(args, "decoder_learned_pos", True)
    args.dropout = safe_getattr(args, "dropout", 0.0)
    args.attention_dropout = safe_getattr(args, "attention_dropout", 0.0)
    args.activation_fn = safe_getattr(args, "activation_fn", "gelu")
    args.share_decoder_input_output_embed = True
    base_lm_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_small")
def transformer_lm_gpt3_small(args):
    # 125M params
    args.decoder_layers = safe_getattr(args, "decoder_layers", 12)
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 768)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 12)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_medium")
def transformer_lm_gpt3_medium(args):
    # 350M params
    args.decoder_layers = safe_getattr(args, "decoder_layers", 24)
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1024)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 16)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_large")
def transformer_lm_gpt3_large(args):
    # 760M params
    args.decoder_layers = safe_getattr(args, "decoder_layers", 24)
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 1536)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 16)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_xl")
def transformer_lm_gpt3_xl(args):
    # 1.3B params
    args.decoder_layers = safe_getattr(args, "decoder_layers", 24)
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 2048)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 32)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_2_7")
def transformer_lm_gpt3_2_7(args):
    # 2.7B params
    args.decoder_layers = safe_getattr(args, "decoder_layers", 32)
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 2560)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 32)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_6_7")
def transformer_lm_gpt3_6_7(args):
    # 6.7B params
    args.decoder_layers = safe_getattr(args, "decoder_layers", 32)
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 4096)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 32)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_13")
def transformer_lm_gpt3_13(args):
    # 13B params
    args.decoder_layers = safe_getattr(args, "decoder_layers", 40)
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 5120)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 40)
    base_gpt3_architecture(args)


@register_model_architecture("transformer_lm", "transformer_lm_gpt3_175")
def transformer_lm_gpt3_175(args):
    # 175B params
    args.decoder_layers = safe_getattr(args, "decoder_layers", 96)
    args.decoder_embed_dim = safe_getattr(args, "decoder_embed_dim", 12288)
    args.decoder_attention_heads = safe_getattr(args, "decoder_attention_heads", 96)
    base_gpt3_architecture(args)