# TomatoCocotree / upload / 6a62ffb
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
FairseqEncoderDecoderModel,
FairseqIncrementalDecoder,
register_model,
register_model_architecture,
)
from fairseq.modules import (
AdaptiveSoftmax,
DynamicConv_scripatable as DynamicConv,
FairseqDropout,
LayerNorm,
LightweightConv,
MultiheadAttention,
PositionalEmbedding,
)
from fairseq.utils import safe_hasattr
from torch import Tensor
@register_model("lightconv")
class LightConvModel(FairseqEncoderDecoderModel):
    """
    LightConv and DynamicConv model from `"Pay Less Attention with Lightweight and Dynamic Convolutions" (Wu, et al, 2019)
    <https://openreview.net/pdf?id=SkVhlh09tX>`_.
    To use LightConv please set ``--encoder-conv-type lightweight --decoder-conv-type lightweight``
    To use DynamicConv please set ``--encoder-conv-type dynamic --decoder-conv-type dynamic``

    Args:
        encoder (LightConvEncoder): the encoder
        decoder (LightConvDecoder): the decoder

    The LightConv model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.lightconv_parser
        :prog:
    """

    @classmethod
    def hub_models(cls):
        """Return a dict mapping pretrained model names to their archive
        URL, tokenizer name and BPE name."""
        # fmt: off

        def moses_subword(path):
            return {
                'path': path,
                'tokenizer': 'moses',
                'bpe': 'subword_nmt',
            }

        return {
            'lightconv.no_glu.iwslt14.de-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.lightconv.tar.gz'),
            'dynamicconv.no_glu.iwslt14.de-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.dynamicconv.tar.gz'),
            'lightconv.no_glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv.tar.gz'),
            'dynamicconv.no_glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv.tar.gz'),
            'lightconv.glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt16.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz'),
            'lightconv.glu.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt17.en-de': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz'),
            'lightconv.glu.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.dynamicconv-glu.tar.gz'),
            'lightconv.glu.wmt17.zh-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.lightconv-glu.tar.gz'),
            'dynamicconv.glu.wmt17.zh-en': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.dynamicconv-glu.tar.gz'),
        }
        # fmt: on

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--dropout", type=float, metavar="D", help="dropout probability"
        )
        parser.add_argument(
            "--attention-dropout",
            type=float,
            metavar="D",
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--relu-dropout",
            type=float,
            metavar="D",
            help="dropout probability after ReLU in FFN",
        )
        parser.add_argument(
            "--input-dropout",
            type=float,
            metavar="D",
            help="dropout probability of the inputs",
        )
        parser.add_argument(
            "--encoder-embed-path",
            type=str,
            metavar="STR",
            help="path to pre-trained encoder embedding",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-conv-dim",
            type=int,
            metavar="N",
            help="encoder convolution dimension",
        )
        parser.add_argument(
            "--encoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="num encoder layers"
        )
        parser.add_argument(
            "--encoder-attention-heads",
            type=int,
            metavar="N",
            help="num encoder attention heads or LightConv/DynamicConv heads",
        )
        parser.add_argument(
            "--encoder-normalize-before",
            action="store_true",
            help="apply layernorm before each encoder block",
        )
        parser.add_argument(
            "--encoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the encoder",
        )
        parser.add_argument(
            "--decoder-embed-path",
            type=str,
            metavar="STR",
            help="path to pre-trained decoder embedding",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-conv-dim",
            type=int,
            metavar="N",
            help="decoder convolution dimension",
        )
        parser.add_argument(
            "--decoder-ffn-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension for FFN",
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="num decoder layers"
        )
        parser.add_argument(
            "--decoder-attention-heads",
            type=int,
            metavar="N",
            help="num decoder attention heads or LightConv/DynamicConv heads",
        )
        parser.add_argument(
            "--decoder-learned-pos",
            action="store_true",
            help="use learned positional embeddings in the decoder",
        )
        parser.add_argument(
            "--decoder-normalize-before",
            action="store_true",
            help="apply layernorm before each decoder block",
        )
        parser.add_argument(
            "--share-decoder-input-output-embed",
            action="store_true",
            help="share decoder input and output embeddings",
        )
        parser.add_argument(
            "--share-all-embeddings",
            action="store_true",
            help="share encoder, decoder and output embeddings"
            " (requires shared dictionary and embed dim)",
        )
        parser.add_argument(
            "--adaptive-softmax-cutoff",
            metavar="EXPR",
            help="comma separated list of adaptive softmax cutoff points. "
            "Must be used with adaptive_loss criterion",
        )
        parser.add_argument(
            "--adaptive-softmax-dropout",
            type=float,
            metavar="D",
            help="sets adaptive softmax dropout for the tail projections",
        )

        # LightConv and DynamicConv arguments
        parser.add_argument(
            "--encoder-kernel-size-list",
            type=lambda x: utils.eval_str_list(x, int),
            help='list of kernel size (default: "[3,7,15,31,31,31,31]")',
        )
        parser.add_argument(
            "--decoder-kernel-size-list",
            type=lambda x: utils.eval_str_list(x, int),
            help='list of kernel size (default: "[3,7,15,31,31,31]")',
        )
        parser.add_argument(
            "--encoder-glu", type=utils.eval_bool, help="glu after in proj"
        )
        parser.add_argument(
            "--decoder-glu", type=utils.eval_bool, help="glu after in proj"
        )
        parser.add_argument(
            "--encoder-conv-type",
            default="dynamic",
            type=str,
            choices=["dynamic", "lightweight"],
            help="type of convolution",
        )
        parser.add_argument(
            "--decoder-conv-type",
            default="dynamic",
            type=str,
            choices=["dynamic", "lightweight"],
            help="type of convolution",
        )
        parser.add_argument("--weight-softmax", default=True, type=utils.eval_bool)
        parser.add_argument(
            "--weight-dropout",
            type=float,
            metavar="D",
            help="dropout probability for conv weights",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args (argparse.Namespace): parsed command-line arguments; missing
                options are filled in by :func:`base_architecture`
            task: object exposing ``source_dictionary`` and
                ``target_dictionary``

        Returns:
            LightConvModel: model wrapping a freshly built encoder and decoder

        Raises:
            RuntimeError: if ``--share-all-embeddings`` is set but the
                dictionaries differ, the embedding dims differ, or a distinct
                ``--decoder-embed-path`` is given
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        if not safe_hasattr(args, "max_source_positions"):
            args.max_source_positions = 1024
        if not safe_hasattr(args, "max_target_positions"):
            args.max_target_positions = 1024

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if src_dict != tgt_dict:
                raise RuntimeError(
                    "--share-all-embeddings requires a joined dictionary"
                )
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise RuntimeError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            # One embedding table serves encoder input, decoder input and
            # (via the flag below) the decoder output projection.
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = LightConvEncoder(args, src_dict, encoder_embed_tokens)
        decoder = LightConvDecoder(args, tgt_dict, decoder_embed_tokens)
        return LightConvModel(encoder, decoder)

    def forward(
        self,
        src_tokens: Tensor,
        src_lengths: Tensor,
        prev_output_tokens: Tensor,
    ):
        """
        (The forward method inherited from the base class has a **kwargs
        argument in its input, which is not supported in torchscript. This
        method overwrites the forward method definition without **kwargs.)

        Run the forward pass for an encoder-decoder model.

        First feed a batch of source tokens through the encoder. Then, feed the
        encoder output and previous decoder outputs (i.e., teacher forcing) to
        the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths)
        decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out)
        return decoder_out
class LightConvEncoder(FairseqEncoder):
    """
    Stack of *args.encoder_layers* :class:`LightConvEncoderLayer` modules on
    top of scaled token embeddings and optional positional embeddings.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)

        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )

        # One layer per configured kernel size, in order.
        self.layers = nn.ModuleList([])
        for layer_idx in range(args.encoder_layers):
            self.layers.append(
                LightConvEncoderLayer(
                    args, kernel_size=args.encoder_kernel_size_list[layer_idx]
                )
            )

        self.register_buffer("version", torch.Tensor([2]))

        # A final LayerNorm is only used in the pre-norm configuration.
        self.normalize = args.encoder_normalize_before
        self.layer_norm = LayerNorm(embed_dim) if self.normalize else None

    def forward(
        self, src_tokens: Tensor, src_lengths: Optional[Tensor] = None
    ) -> Dict[str, List[Tensor]]:
        """
        Encode a batch of source tokens.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (Tensor, optional): passed through unchanged under the
                ``"src_lengths"`` key when given

        Returns:
            dict of single-element lists:
                - **encoder_out**: the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask**: positions of padding elements,
                  shape `(batch, src_len)`; only present when the batch
                  contains at least one padding token
                - **src_lengths**: only present when *src_lengths* was given
        """
        # scaled token embeddings, plus positions when enabled
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # Only carry a mask around when something is actually padded.
        padding_positions = src_tokens.eq(self.padding_idx)  # B x T
        encoder_mask: Optional[Tensor] = None
        if padding_positions.any():
            encoder_mask = padding_positions

        for layer in self.layers:
            x = layer(x, encoder_mask)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        output_dict: Dict[str, List[Tensor]] = {}
        if src_lengths is not None:
            output_dict["src_lengths"] = [src_lengths]
        output_dict["encoder_out"] = [x]  # T x B x C
        if encoder_mask is not None:
            output_dict["encoder_padding_mask"] = [encoder_mask]  # B x T
        return output_dict

    @torch.jit.export
    def reorder_encoder_out(
        self, encoder_out: Dict[str, List[Tensor]], new_order: Tensor
    ):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            dict with ``"encoder_out"`` and ``"encoder_padding_mask"`` lists
            rearranged along the batch dimension (empty lists when the
            corresponding input entry is missing or empty)
        """
        states: List[Tensor] = []
        if len(encoder_out["encoder_out"]) > 0:
            # encoder states are T x B x C, so the batch axis is 1
            states = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        output_dict = {"encoder_out": states}

        masks: List[Tensor] = []
        if (
            "encoder_padding_mask" in encoder_out
            and len(encoder_out["encoder_padding_mask"]) > 0
        ):
            # padding mask is B x T, so the batch axis is 0
            masks = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        output_dict["encoder_padding_mask"] = masks
        return output_dict

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        limit = self.max_source_positions
        if self.embed_positions is not None:
            limit = min(limit, self.embed_positions.max_positions)
        return limit
class LightConvDecoder(FairseqIncrementalDecoder):
    """
    LightConv decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`LightConvDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
            Default: ``False``
        final_norm (bool, optional): allow the final LayerNorm; it is only
            created when ``args.decoder_normalize_before`` is also true.
            Default: ``True``
    """

    def __init__(
        self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True
    ):
        super().__init__(dictionary)
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.share_input_output_embed = args.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = args.decoder_embed_dim
        output_embed_dim = args.decoder_output_dim

        # These three options are neither registered by
        # LightConvModel.add_args nor defaulted by base_architecture in the
        # visible part of this file, so reading them directly raises
        # AttributeError as soon as adaptive softmax (or a distinct output
        # dim) is requested. Fall back to "untied" behaviour when absent.
        tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
        adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
        tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                LightConvDecoderLayer(
                    args,
                    no_encoder_attn,
                    kernel_size=args.decoder_kernel_size_list[i],
                    dictionary=dictionary,
                )
                for i in range(args.decoder_layers)
            ]
        )

        self.adaptive_softmax = None
        self.output_projection = None

        self.project_out_dim = (
            Linear(embed_dim, output_embed_dim, bias=False)
            if embed_dim != output_embed_dim and not tie_adaptive_weights
            else None
        )

        if args.adaptive_softmax_cutoff is not None:
            self.adaptive_softmax = AdaptiveSoftmax(
                len(dictionary),
                output_embed_dim,
                utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                dropout=args.adaptive_softmax_dropout,
                adaptive_inputs=embed_tokens if tie_adaptive_weights else None,
                factor=adaptive_softmax_factor,
                tie_proj=tie_adaptive_proj,
            )
        elif self.share_input_output_embed:
            # Output projection reuses the input embedding matrix.
            self.output_projection = nn.Linear(
                self.embed_tokens.weight.shape[1],
                self.embed_tokens.weight.shape[0],
                bias=False,
            )
            self.output_projection.weight = self.embed_tokens.weight
        else:
            self.output_projection = nn.Linear(
                output_embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight, mean=0, std=output_embed_dim**-0.5
            )

        self.register_buffer("version", torch.Tensor([2]))
        self.normalize = args.decoder_normalize_before and final_norm
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self,
        prev_output_tokens: Tensor,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        src_lengths: Optional[Any] = None,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (dict, optional): output from the encoder; the first
                elements of its ``"encoder_out"`` and
                ``"encoder_padding_mask"`` lists are used for encoder-side
                attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`; when given, only the last target
                position is processed
            src_lengths: accepted for interface compatibility; not used here

        Returns:
            tuple:
                - the last decoder layer's output, `(batch, tgt_len, vocab)`
                  after the output projection; when adaptive softmax is
                  configured the projection is skipped and the features are
                  returned instead
                - a dict with ``"attn"`` (single-element list holding the last
                  layer's attention weights, possibly ``None``) and
                  ``"inner_states"`` (list of per-layer states, T x B x C)
        """
        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            # Incremental decoding: only the newest token needs computing.
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens.contiguous())

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        inner_states: List[Optional[Tensor]] = [x]

        # The encoder tensors are the same for every layer, so unpack once.
        encoder: Optional[Tensor] = None
        encoder_padding_mask: Optional[Tensor] = None
        if encoder_out is not None:
            if len(encoder_out["encoder_out"]) > 0:
                encoder = encoder_out["encoder_out"][0]
            if (
                "encoder_padding_mask" in encoder_out
                and len(encoder_out["encoder_padding_mask"]) > 0
            ):
                encoder_padding_mask = encoder_out["encoder_padding_mask"][0]

        # decoder layers
        attn: Optional[Tensor] = None
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder,
                encoder_padding_mask,
                incremental_state,
            )
            inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        if self.adaptive_softmax is None:
            # project back to size of vocabulary
            x = self.output_projection(x)

        return x, {"attn": [attn], "inner_states": inner_states}

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        """Return a cached `(dim, dim)` mask, ``dim = tensor.size(0)``.

        The mask is built with ``utils.fill_with_neg_inf`` and
        ``torch.triu(..., 1)``, i.e. only entries strictly above the main
        diagonal keep the fill value. The cache is rebuilt when missing or on
        a different device, and grown when it is too small.
        """
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]
class LightConvEncoderLayer(nn.Module):
    """Encoder layer: a (light/dynamic) convolution sub-block followed by a
    feed-forward sub-block, each with a residual connection and LayerNorm.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        kernel_size: kernel size of the convolution
    """

    def __init__(self, args, kernel_size=0):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.conv_dim = args.encoder_conv_dim

        # Odd kernels get a single symmetric padding; even kernels get a
        # (left, right) pair.
        if kernel_size % 2 == 1:
            padding_l = kernel_size // 2
        else:
            padding_l = ((kernel_size - 1) // 2, kernel_size // 2)

        # With GLU the input projection is twice as wide, since GLU halves
        # the last dimension again.
        if args.encoder_glu:
            self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim)
            self.act = nn.GLU()
        else:
            self.linear1 = Linear(self.embed_dim, self.conv_dim)
            self.act = None

        conv_type = args.encoder_conv_type
        if conv_type == "lightweight":
            conv_cls = LightweightConv
        elif conv_type == "dynamic":
            conv_cls = DynamicConv
        else:
            raise NotImplementedError
        self.conv = conv_cls(
            self.conv_dim,
            kernel_size,
            padding_l=padding_l,
            weight_softmax=args.weight_softmax,
            num_heads=args.encoder_attention_heads,
            weight_dropout=args.weight_dropout,
        )
        self.linear2 = Linear(self.conv_dim, self.embed_dim)

        owner = self.__class__.__name__
        self.dropout_module = FairseqDropout(args.dropout, module_name=owner)
        self.relu_dropout_module = FairseqDropout(
            args.relu_dropout, module_name=owner
        )
        self.input_dropout_module = FairseqDropout(
            args.input_dropout, module_name=owner
        )

        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.layer_norm1 = LayerNorm(self.embed_dim)
        self.layer_norm2 = LayerNorm(self.embed_dim)

    def forward(self, x, encoder_padding_mask: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.

        Returns:
            Tensor with the same shape as *x* (it is the sum of residual
            connections with the input)
        """
        # ---- convolution sub-block ----
        shortcut = x
        if self.maybe_layer_norm(before=True):
            x = self.layer_norm1(x)
        x = self.linear1(self.input_dropout_module(x))
        if self.act is not None:
            x = self.act(x)
        if encoder_padding_mask is not None:
            # B x T -> T x B x 1, then zero padded positions before the conv
            pad = encoder_padding_mask.transpose(0, 1).unsqueeze(2)
            x = x.masked_fill(pad, 0)
        x = self.linear2(self.conv(x))
        x = shortcut + self.dropout_module(x)
        if self.maybe_layer_norm(after=True):
            x = self.layer_norm1(x)

        # ---- feed-forward sub-block ----
        shortcut = x
        if self.maybe_layer_norm(before=True):
            x = self.layer_norm2(x)
        x = self.relu_dropout_module(F.relu(self.fc1(x)))
        x = self.dropout_module(self.fc2(x))
        x = shortcut + x
        if self.maybe_layer_norm(after=True):
            x = self.layer_norm2(x)
        return x

    def maybe_layer_norm(self, before: bool = False, after: bool = False):
        """Tell whether LayerNorm applies at this point: at ``before`` sites
        when ``normalize_before`` is set, at ``after`` sites otherwise.
        Exactly one of *before* / *after* must be true."""
        assert before ^ after, "Incorrect arguments"
        return after ^ self.normalize_before

    def extra_repr(self):
        template = (
            "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}"
        )
        values = (
            self.dropout_module.p,
            self.relu_dropout_module.p,
            self.input_dropout_module.p,
            self.normalize_before,
        )
        return template.format(*values)
class LightConvDecoderLayer(nn.Module):
    """Decoder layer: a left-padded (light/dynamic) convolution sub-block, an
    optional encoder-attention sub-block and a feed-forward sub-block, each
    with a residual connection and LayerNorm.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs.
            Default: ``False``
        kernel_size: kernel size of the convolution
        dictionary: forwarded to :class:`MultiheadAttention`
    """

    def __init__(self, args, no_encoder_attn=False, kernel_size=0, dictionary=None):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.conv_dim = args.decoder_conv_dim

        # With GLU the input projection is twice as wide, since GLU halves
        # the last dimension again.
        if args.decoder_glu:
            self.linear1 = Linear(self.embed_dim, 2 * self.conv_dim)
            self.act = nn.GLU()
        else:
            self.linear1 = Linear(self.embed_dim, self.conv_dim)
            self.act = None

        conv_type = args.decoder_conv_type
        if conv_type == "lightweight":
            conv_cls = LightweightConv
        elif conv_type == "dynamic":
            conv_cls = DynamicConv
        else:
            raise NotImplementedError
        self.conv = conv_cls(
            self.conv_dim,
            kernel_size,
            # all padding goes on the left
            padding_l=kernel_size - 1,
            weight_softmax=args.weight_softmax,
            num_heads=args.decoder_attention_heads,
            weight_dropout=args.weight_dropout,
        )
        self.linear2 = Linear(self.conv_dim, self.embed_dim)

        owner = self.__class__.__name__
        self.dropout_module = FairseqDropout(args.dropout, module_name=owner)
        self.relu_dropout_module = FairseqDropout(
            args.relu_dropout, module_name=owner
        )
        self.input_dropout_module = FairseqDropout(
            args.input_dropout, module_name=owner
        )

        self.normalize_before = args.decoder_normalize_before
        self.conv_layer_norm = LayerNorm(self.embed_dim)

        if not no_encoder_attn:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
                dictionary=dictionary,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
        else:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = LayerNorm(self.embed_dim)
        self.need_attn = True

    def forward(
        self,
        x: Tensor,
        encoder_out: Optional[Tensor],
        encoder_padding_mask: Optional[Tensor],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        prev_conv_state: Optional[Tensor] = None,
        prev_attn_state: Optional[Tuple[Tensor, Tensor]] = None,
        conv_mask: Optional[Tensor] = None,
        conv_padding_mask: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_out (Tensor, optional): encoder states used as attention
                keys and values
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
            incremental_state (dict, optional): state for incremental decoding,
                passed to the convolution and to encoder attention
            prev_conv_state (Tensor, optional): if given, installed as the
                convolution's input buffer before it runs
            prev_attn_state (tuple, optional): if given, its two items are
                installed as encoder attention's ``prev_key`` / ``prev_value``
            conv_mask, conv_padding_mask: accepted but not used by this method

        Returns:
            tuple:
                - Tensor with the same shape as *x*
                - encoder attention weights, or ``None`` when the layer has no
                  encoder attention
        """
        # ---- convolution sub-block ----
        shortcut = x
        if self.maybe_layer_norm(before=True):
            x = self.conv_layer_norm(x)
        if prev_conv_state is not None:
            self.conv._set_input_buffer(incremental_state, prev_conv_state)
        x = self.linear1(self.input_dropout_module(x))
        if self.act is not None:
            x = self.act(x)
        x = self.conv(x, incremental_state=incremental_state)
        x = self.dropout_module(self.linear2(x))
        x = shortcut + x
        if self.maybe_layer_norm(after=True):
            x = self.conv_layer_norm(x)

        # ---- encoder attention sub-block (optional) ----
        attn: Optional[Tensor] = None
        if self.encoder_attn is not None:
            shortcut = x
            if self.maybe_layer_norm(before=True):
                x = self.encoder_attn_layer_norm(x)

            if prev_attn_state is not None:
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_attn_state[0],
                    "prev_value": prev_attn_state[1],
                }
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                # weights are only requested outside training
                need_weights=(not self.training and self.need_attn),
            )
            x = shortcut + self.dropout_module(x)
            if self.maybe_layer_norm(after=True):
                x = self.encoder_attn_layer_norm(x)

        # ---- feed-forward sub-block ----
        shortcut = x
        if self.maybe_layer_norm(before=True):
            x = self.final_layer_norm(x)
        x = self.relu_dropout_module(F.relu(self.fc1(x)))
        x = self.dropout_module(self.fc2(x))
        x = shortcut + x
        if self.maybe_layer_norm(after=True):
            x = self.final_layer_norm(x)
        return x, attn

    def maybe_layer_norm(self, before: bool = False, after: bool = False):
        """Tell whether LayerNorm applies at this point: at ``before`` sites
        when ``normalize_before`` is set, at ``after`` sites otherwise.
        Exactly one of *before* / *after* must be true."""
        assert before ^ after, "Incorrect usage"
        return after ^ self.normalize_before

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        """Record whether attention weights should be returned at inference."""
        self.need_attn = need_attn

    def extra_repr(self):
        template = (
            "dropout={}, relu_dropout={}, input_dropout={}, normalize_before={}"
        )
        values = (
            self.dropout_module.p,
            self.relu_dropout_module.p,
            self.input_dropout_module.p,
            self.normalize_before,
        )
        return template.format(*values)
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with normally-initialised weights.

    Weights are drawn from N(0, embedding_dim ** -0.5) and the row at
    *padding_idx* is then reset to zero.

    Args:
        num_embeddings (int): vocabulary size
        embedding_dim (int): size of each embedding vector
        padding_idx (int or None): index of the padding row; ``None`` means
            no padding row

    Returns:
        torch.nn.Embedding: the initialised module
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    # Indexing with None yields a view of the *whole* matrix, which would
    # zero every embedding; only touch the padding row when there is one.
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
def Linear(in_features, out_features, bias=True):
    """Create an ``nn.Linear`` with Xavier-uniform weights and, when a bias
    is requested, a zero bias.

    Args:
        in_features (int): input feature size
        out_features (int): output feature size
        bias (bool, optional): whether the layer has a bias. Default: ``True``

    Returns:
        torch.nn.Linear: the initialised module
    """
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    if not bias:
        return layer
    nn.init.constant_(layer.bias, 0.0)
    return layer
@register_model_architecture("lightconv", "lightconv")
def base_architecture(args):
    """Fill in every LightConv hyper-parameter that *args* does not define.

    Attributes already present on *args* are kept; missing ones receive the
    defaults below (some of which mirror earlier attributes, e.g. decoder
    dims default to the encoder dims). A single-element kernel-size list is
    repeated once per layer.

    Args:
        args (argparse.Namespace): namespace to update in place

    Raises:
        AssertionError: if a kernel-size list length does not match the
            corresponding number of layers
    """

    def fill(name, fallback):
        # Keep an existing value, otherwise install the fallback.
        setattr(args, name, getattr(args, name, fallback))

    fill("encoder_embed_path", None)
    fill("encoder_embed_dim", 512)
    fill("encoder_ffn_embed_dim", 2048)
    fill("encoder_layers", 7)
    fill("encoder_attention_heads", 8)
    fill("encoder_normalize_before", False)
    fill("encoder_learned_pos", False)

    fill("decoder_embed_path", None)
    fill("decoder_embed_dim", args.encoder_embed_dim)
    fill("decoder_ffn_embed_dim", args.encoder_ffn_embed_dim)
    fill("decoder_layers", 6)
    fill("decoder_attention_heads", 8)
    fill("decoder_normalize_before", False)
    fill("decoder_learned_pos", False)

    fill("attention_dropout", 0.0)
    fill("relu_dropout", 0.0)
    fill("dropout", 0.1)
    fill("adaptive_softmax_cutoff", None)
    fill("adaptive_softmax_dropout", 0)
    fill("share_decoder_input_output_embed", False)
    fill("share_all_embeddings", False)
    fill("no_token_positional_embeddings", False)

    fill("decoder_output_dim", args.decoder_embed_dim)
    fill("decoder_input_dim", args.decoder_embed_dim)

    fill("encoder_conv_dim", args.encoder_embed_dim)
    fill("decoder_conv_dim", args.decoder_embed_dim)

    fill("encoder_kernel_size_list", [3, 7, 15, 31, 31, 31, 31])
    fill("decoder_kernel_size_list", [3, 7, 15, 31, 31, 31])

    # A single kernel size means "use it for every layer".
    num_enc, num_dec = args.encoder_layers, args.decoder_layers
    if len(args.encoder_kernel_size_list) == 1:
        args.encoder_kernel_size_list = args.encoder_kernel_size_list * num_enc
    if len(args.decoder_kernel_size_list) == 1:
        args.decoder_kernel_size_list = args.decoder_kernel_size_list * num_dec

    assert (
        len(args.encoder_kernel_size_list) == num_enc
    ), "encoder_kernel_size_list doesn't match encoder_layers"
    assert (
        len(args.decoder_kernel_size_list) == num_dec
    ), "decoder_kernel_size_list doesn't match decoder_layers"

    fill("encoder_glu", True)
    fill("decoder_glu", True)
    fill("input_dropout", 0.1)
    fill("weight_dropout", args.attention_dropout)
@register_model_architecture("lightconv", "lightconv_iwslt_de_en")
def lightconv_iwslt_de_en(args):
    """Apply the ``lightconv_iwslt_de_en`` defaults to *args*, then the base
    defaults. Attributes already set on *args* are left untouched."""
    preset = (
        ("encoder_embed_dim", 512),
        ("encoder_ffn_embed_dim", 1024),
        ("encoder_attention_heads", 4),
        ("encoder_layers", 7),
        ("decoder_embed_dim", 512),
        ("decoder_ffn_embed_dim", 1024),
        ("decoder_attention_heads", 4),
        ("decoder_layers", 6),
        ("attention_dropout", 0.1),
        ("weight_dropout", 0.1),
        ("encoder_glu", False),
        ("decoder_glu", False),
        ("input_dropout", 0.0),
    )
    for name, fallback in preset:
        setattr(args, name, getattr(args, name, fallback))
    base_architecture(args)
@register_model_architecture("lightconv", "lightconv_wmt_en_de")
def lightconv_wmt_en_de(args):
    """Register ``lightconv_wmt_en_de`` as an alias of the base defaults."""
    base_architecture(args)
@register_model_architecture("lightconv", "lightconv_wmt_en_de_big")
def lightconv_wmt_en_de_big(args):
    """Apply the ``lightconv_wmt_en_de_big`` defaults to *args*, then the base
    defaults. Attributes already set on *args* are left untouched."""
    preset = (
        ("attention_dropout", 0.1),
        ("encoder_embed_dim", 1024),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_attention_heads", 16),
        ("encoder_normalize_before", False),
        ("decoder_embed_dim", 1024),
        ("decoder_ffn_embed_dim", 4096),
        ("decoder_attention_heads", 16),
        ("dropout", 0.3),
    )
    for name, fallback in preset:
        setattr(args, name, getattr(args, name, fallback))
    base_architecture(args)
@register_model_architecture("lightconv", "lightconv_wmt_en_fr_big")
def lightconv_wmt_en_fr_big(args):
    """Default ``dropout`` to 0.1 if unset, then apply the
    ``lightconv_wmt_en_de_big`` defaults."""
    if not hasattr(args, "dropout"):
        args.dropout = 0.1
    lightconv_wmt_en_de_big(args)
@register_model_architecture("lightconv", "lightconv_wmt_zh_en_big")
def lightconv_wmt_zh_en_big(args):
    """Default ``dropout``, ``attention_dropout`` and ``weight_dropout`` to
    0.2 if unset, then apply the ``lightconv_wmt_en_de_big`` defaults."""
    for name in ("dropout", "attention_dropout", "weight_dropout"):
        setattr(args, name, getattr(args, name, 0.2))
    lightconv_wmt_en_de_big(args)