herwoww's picture
first upload
1547a56
raw
history blame
15.1 kB
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
from typing import Dict, List
import numpy as np
import torch
import torch.nn as nn
import contextlib
from fairseq import utils
from fairseq.models import (
FairseqEncoder,
)
from fairseq.modules import (
FairseqDropout,
LayerNorm,
TransformerEncoderLayer,
)
from torch import Tensor
from .transformer_layer import TransformerSentenceEncoderLayer
# Parameter-count threshold (100 million). Not referenced elsewhere in this
# file; presumably kept for parity with fairseq's FSDP wrapping threshold —
# confirm before relying on it.
DEFAULT_MIN_PARAMS_TO_WRAP = 100_000_000
def Linear(in_features, out_features, bias=True):
    """Build an ``nn.Linear`` with Xavier-uniform weights and a zeroed bias.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the layer has an additive bias; if so it starts at zero.

    Returns:
        The initialised ``nn.Linear`` module.
    """
    projection = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(projection.weight)
    if bias:
        # start with no offset so the layer is purely linear at init
        nn.init.zeros_(projection.bias)
    return projection
class RelativePositionalEncoding(torch.nn.Module):
    """Learned embeddings for clipped relative positions.

    Relative offsets are clamped to ``[-maxlen, maxlen - 1]`` and shifted by
    ``maxlen`` so they index an embedding table with ``2 * maxlen`` rows.

    Args:
        d_model: size of each relative-position embedding vector.
        maxlen: clipping distance for relative offsets.
        embed_v: if True, a second table ``pe_v`` is learned and its lookup
            is returned as the second element of ``forward``'s result.
    """

    def __init__(self, d_model, maxlen=1000, embed_v=False):
        super(RelativePositionalEncoding, self).__init__()
        self.d_model = d_model
        self.maxlen = maxlen
        self.pe_k = torch.nn.Embedding(2 * maxlen, d_model)
        if embed_v:
            self.pe_v = torch.nn.Embedding(2 * maxlen, d_model)
        self.embed_v = embed_v

    def forward(self, pos_seq):
        """Look up embeddings for the integer relative offsets in ``pos_seq``.

        Args:
            pos_seq (LongTensor): relative offsets of any shape. The tensor is
                not modified.

        Returns:
            tuple: ``(pe_k(idx), pe_v(idx))`` when ``embed_v`` is set, else
            ``(pe_k(idx), None)``, where ``idx`` is ``pos_seq`` clamped to
            ``[-maxlen, maxlen - 1]`` and shifted by ``maxlen``.
        """
        # Out-of-place clamp: the previous in-place masked assignment
        # silently rewrote the caller's tensor.
        idx = torch.clamp(pos_seq, min=-self.maxlen, max=self.maxlen - 1) + self.maxlen
        if self.embed_v:
            return self.pe_k(idx), self.pe_v(idx)
        return self.pe_k(idx), None
class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerSentenceEncoderLayer` when ``args.use_sent_enc_layer``
    is set, otherwise a fairseq :class:`TransformerEncoderLayer`.

    The encoder has no token embedding of its own: it takes already-embedded
    inputs (``encoder_in``) and can optionally project its output to the
    target vocabulary for CTC.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        tgt_dict (~fairseq.data.Dictionary, optional): target dictionary; when
            given (and the shared-embedding case below does not apply) a CTC
            projection with ``len(tgt_dict)`` outputs is built.
        embed_tokens (torch.nn.Embedding, optional): when given together with
            ``args.share_ctc_embed``, the CTC projection shares its weight.
    """

    def __init__(self, args, tgt_dict=None, embed_tokens=None):
        self.args = args
        super().__init__(None)
        self.register_buffer("version", torch.Tensor([3]))
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.encoder_layerdrop = args.encoder_layerdrop
        # Gradients are (fully or partially) disabled until this many updates.
        self.freeze_encoder_updates = args.freeze_encoder_updates
        if args.no_freeze_encoder_layer is not None:
            # NOTE(review): eval() executes arbitrary code from the config.
            # The result is used with `in` against layer indices, so it is
            # expected to be a literal such as "[0, 1]"; ast.literal_eval would
            # be safer if configs only ever contain literals — confirm first.
            self.no_freeze_encoder_layer = eval(args.no_freeze_encoder_layer)
        else:
            self.no_freeze_encoder_layer = None
        self.num_updates = 0
        export = getattr(args, "export", False)
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [self.build_encoder_layer(args) for i in range(args.encoder_layers)]
        )
        self.num_layers = len(self.layers)
        self.use_sent_enc_layer = args.use_sent_enc_layer
        # Index of a layer that is never dropped by LayerDrop and whose output
        # is returned as "decoder_input" (-1 disables).
        self.unb_enc_layer = getattr(args, "unb_enc_layer", -1)
        # True: apply self.layer_norm after the layer stack;
        # False: apply it to the inputs before the stack.
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(args.encoder_embed_dim, eps=args.layer_norm_eps, export=export)
        if args.share_ctc_embed and embed_tokens is not None:
            self.proj = nn.Linear(
                embed_tokens.weight.shape[1],
                embed_tokens.weight.shape[0],
                bias=False,
            )
            self.proj.weight = embed_tokens.weight
        elif tgt_dict is not None:
            self.proj = Linear(args.encoder_embed_dim, len(tgt_dict))
        else:
            self.proj = None
        if args.relative_position_embedding:
            # one relative-position vector per attention head dimension
            self.pos_emb = RelativePositionalEncoding(args.encoder_embed_dim // args.encoder_attention_heads, args.encoder_max_relative_position)

    def build_encoder_layer(self, args):
        """Return one encoder layer of the type selected by ``args.use_sent_enc_layer``."""
        if args.use_sent_enc_layer:
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=args.encoder_embed_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
                has_relative_attention_bias=args.relative_position_embedding,
            )
        else:
            layer = TransformerEncoderLayer(args)
        return layer

    def forward(
        self,
        encoder_in,
        encoder_padding_mask,
        return_all_hiddens: bool = False,
        tgt_layer=None,
    ):
        """
        Run the encoder and, if a CTC projection exists, project its output.

        When ``no_freeze_encoder_layer`` is None, the whole encoder runs under
        ``torch.no_grad()`` until ``freeze_encoder_updates`` updates have been
        done. Partial freezing is handled in :meth:`forward_scriptable`.

        Args:
            encoder_in (Tensor): embedded inputs of shape
                `(batch, src_len, embed_dim)`
            encoder_padding_mask (Tensor): padding positions (non-zero at
                padding) of shape `(batch, src_len)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            tgt_layer (int, optional): stop after the layer with this index and
                return its output as ``encoder_out``.

        Returns:
            dict: everything returned by :meth:`forward_scriptable`, plus
                - **encoder_out_for_ctc** (list): one element, the projection of
                  the (dropout-applied) ``encoder_out`` of shape
                  `(src_len, batch, vocab)`, or ``None`` when no projection
                  was built.
        """
        if self.no_freeze_encoder_layer is None:
            ft = self.freeze_encoder_updates <= self.num_updates
        else:
            ft = True
        with torch.no_grad() if not ft else contextlib.ExitStack():
            encoder_out = self.forward_scriptable(
                encoder_in, encoder_padding_mask, return_all_hiddens, tgt_layer=tgt_layer,
            )
        # CTC and bert
        if self.proj is not None:
            x_for_ctc = self.proj(self.dropout_module(encoder_out["encoder_out"][0]))
        else:
            x_for_ctc = None
        encoder_out["encoder_out_for_ctc"] = [x_for_ctc]  # T x B x C
        return encoder_out

    # TorchScript doesn't support super() method so that the scriptable Subclass
    # can't access the base class model in Torchscript.
    # Current workaround is to add a helper function with different name and
    # call the helper function from scriptable Subclass.
    def forward_scriptable(
        self,
        encoder_in,
        encoder_padding_mask,
        return_all_hiddens: bool = False,
        tgt_layer=None,
    ):
        """
        Run the layer stack (without the CTC projection).

        When ``no_freeze_encoder_layer`` is set, then until
        ``freeze_encoder_updates`` updates have been done, input
        preprocessing, the final layer norm and every layer whose index is
        not in ``no_freeze_encoder_layer`` run under ``torch.no_grad()``.

        Args:
            encoder_in (Tensor): embedded inputs of shape
                `(batch, src_len, embed_dim)`
            encoder_padding_mask (Tensor): padding positions (non-zero at
                padding) of shape `(batch, src_len)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            tgt_layer (int, optional): stop after the layer with this index;
                its output (without the final layer norm) becomes
                ``encoder_out``.

        Returns:
            dict:
                - **encoder_out** (List[Tensor]): the encoder output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (List[Tensor]): the input padding
                  mask of shape `(batch, src_len)`
                - **encoder_states** (List[Tensor]): the (dropout-applied)
                  input followed by each layer's output, all of shape
                  `(src_len, batch, embed_dim)`. Only populated if
                  *return_all_hiddens* is True.
                - **src_tokens** (list): always empty.
                - **decoder_input** (list): one element, the output of layer
                  ``unb_enc_layer`` (`(src_len, batch, embed_dim)`) or None.
        """
        if self.no_freeze_encoder_layer is not None:
            ft = self.freeze_encoder_updates <= self.num_updates
        else:
            ft = True
        with torch.no_grad() if not ft else contextlib.ExitStack():
            # compute padding mask
            if not self.use_sent_enc_layer:
                # on XLA devices the mask is always passed to the layers
                has_pads = encoder_in.device.type == "xla" or encoder_padding_mask.any()
            if not self.layer_norm_first:
                encoder_in = self.layer_norm(encoder_in)
            encoder_in = self.dropout_module(encoder_in)
            # B x T x C -> T x B x C
            x = encoder_in.transpose(0, 1)
            encoder_states = []
            if return_all_hiddens:
                encoder_states.append(x)
            ## relative position embedding
            if self.args.relative_position_embedding:
                x_len = x.shape[0]
                pos_seq = torch.arange(0, x_len).long().to(x.device)
                # (T, T) matrix of offsets i - j
                pos_seq = pos_seq[:, None] - pos_seq[None, :]
                pos_k, pos_v = self.pos_emb(pos_seq)
            else:
                pos_k = None
        # encoder layers
        r = None
        d = None
        for i, layer in enumerate(self.layers):
            # LayerDrop draw; taken for every layer, also outside training
            dropout_probability = np.random.random()
            with torch.no_grad() if (not ft) and i not in self.no_freeze_encoder_layer else contextlib.ExitStack():
                if not self.training or (dropout_probability > self.encoder_layerdrop) or i == self.unb_enc_layer:
                    if self.use_sent_enc_layer:
                        x, _ = layer(x, self_attn_padding_mask=encoder_padding_mask, self_attn_mask=None, need_weights=False, pos_bias=pos_k)
                    else:
                        x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None, attn_mask=None)
                if i == self.unb_enc_layer:
                    d = x
                if i == tgt_layer:
                    r = x
                    break
                if return_all_hiddens:
                    assert encoder_states is not None
                    encoder_states.append(x)
        with torch.no_grad() if not ft else contextlib.ExitStack():
            # Finally T x B x C
            if self.layer_norm_first:
                x = self.layer_norm(x.transpose(0, 1)).transpose(0, 1)
            if r is not None:
                # early exit at tgt_layer: return that layer's raw output
                x = r
        # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
        # `forward` so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "decoder_input": [d],  # T x B x C (or None)
        }

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order of the batch entries

        Returns:
            *encoder_out* rearranged according to *new_order*. Time-major
            tensors (`T x B x C`) are reordered along dim 1, batch-major ones
            (`B x T`) along dim 0. The ``encoder_states`` list is updated in
            place and returned.
        """
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        # forward() stores [None] here when no CTC projection was built; keep
        # that value instead of calling index_select on None.
        if len(encoder_out["encoder_out_for_ctc"]) == 0:
            new_x_for_ctc = []
        elif encoder_out["encoder_out_for_ctc"][0] is None:
            new_x_for_ctc = [None]
        else:
            new_x_for_ctc = [encoder_out["encoder_out_for_ctc"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        if len(encoder_out["src_tokens"]) == 0:
            src_tokens = []
        else:
            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
        # decoder_input is an intermediate layer output, time-major like
        # encoder_out (T x B x C), so the batch axis is dim 1, not dim 0.
        if len(encoder_out["decoder_input"]) == 0 or encoder_out["decoder_input"][0] is None:
            new_decoder_input = []
        else:
            new_decoder_input = [
                encoder_out["decoder_input"][0].index_select(1, new_order)
            ]
        encoder_states = encoder_out["encoder_states"]
        for idx, state in enumerate(encoder_states):
            encoder_states[idx] = state.index_select(1, new_order)
        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": src_tokens,  # B x T
            "encoder_out_for_ctc": new_x_for_ctc,  # T x B x C
            "decoder_input": new_decoder_input,  # T x B x C
        }

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Delegates to every fairseq ``TransformerEncoderLayer``; sentence
        encoder layers are skipped. It then handles checkpoints whose
        ``version`` entry is below 2.
        """
        for i, layer in enumerate(self.layers):
            # update layer norms
            if not isinstance(layer, TransformerSentenceEncoderLayer):
                layer.upgrade_state_dict_named(
                    state_dict, "{}.layers.{}".format(name, i)
                )
        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            # NOTE(review): forward_scriptable calls self.layer_norm on either
            # branch of layer_norm_first, so a model upgraded through this
            # path cannot run forward as written — confirm such checkpoints
            # are never loaded into this class.
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates (drives encoder freezing)."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates