# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import contextlib
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.models import FairseqEncoder
from fairseq.modules import (
    FairseqDropout,
    LayerNorm,
    TransformerEncoderLayer,
)

from .transformer_layer import TransformerSentenceEncoderLayer

DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m


class RelativePositionalEncoding(torch.nn.Module):
    """Learned relative positional embeddings, with offsets clipped to [-maxlen, maxlen - 1]."""

    def __init__(self, d_model, maxlen=1000, embed_v=False):
        super(RelativePositionalEncoding, self).__init__()
        self.d_model = d_model
        self.maxlen = maxlen
        self.pe_k = torch.nn.Embedding(2 * maxlen, d_model)
        if embed_v:
            self.pe_v = torch.nn.Embedding(2 * maxlen, d_model)
        self.embed_v = embed_v

    def forward(self, pos_seq):
        # Clip relative offsets to the supported range, then shift them into
        # [0, 2 * maxlen) so they can index the embedding tables.
        pos_seq[pos_seq < -self.maxlen] = -self.maxlen
        pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
        pos_seq = pos_seq + self.maxlen
        if self.embed_v:
            return self.pe_k(pos_seq), self.pe_v(pos_seq)
        else:
            return self.pe_k(pos_seq), None
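

# The helper below is an illustrative sketch only (it is not part of the
# original ArTST code and is never called by the model). It mirrors how
# forward_scriptable() builds the (T x T) matrix of relative offsets that
# RelativePositionalEncoding clips and embeds per attention head.
def _relative_position_example(seq_len=8, head_dim=16):
    """Hypothetical demo of RelativePositionalEncoding; placeholder sizes."""
    pe = RelativePositionalEncoding(head_dim, maxlen=4)
    pos_seq = torch.arange(0, seq_len).long()
    pos_seq = pos_seq[:, None] - pos_seq[None, :]  # offsets in [-(T-1), T-1], clipped inside forward()
    pos_k, pos_v = pe(pos_seq)                     # pos_k: (T, T, head_dim); pos_v is None (embed_v=False)
    return pos_k, pos_v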


class TransformerEncoder(FairseqEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`, or a
    :class:`TransformerSentenceEncoderLayer` when *args.use_sent_enc_layer* is set.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        tgt_dict (~fairseq.data.Dictionary, optional): target dictionary used
            to size the CTC projection
        embed_tokens (torch.nn.Embedding, optional): output embedding, tied to
            the CTC projection when *args.share_ctc_embed* is set
    """

    def __init__(self, args, tgt_dict=None, embed_tokens=None):
        self.args = args
        super().__init__(None)
        self.register_buffer("version", torch.Tensor([3]))

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.encoder_layerdrop = args.encoder_layerdrop
        self.freeze_encoder_updates = args.freeze_encoder_updates
        if args.no_freeze_encoder_layer is not None:
            # Parse the config string into the collection of layer indices that
            # stay trainable while the rest of the encoder is frozen.
            self.no_freeze_encoder_layer = eval(args.no_freeze_encoder_layer)
        else:
            self.no_freeze_encoder_layer = None
        self.num_updates = 0
        export = getattr(args, "export", False)

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [self.build_encoder_layer(args) for i in range(args.encoder_layers)]
        )
        self.num_layers = len(self.layers)

        self.use_sent_enc_layer = args.use_sent_enc_layer
        self.unb_enc_layer = getattr(args, "unb_enc_layer", -1)

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(args.encoder_embed_dim, eps=args.layer_norm_eps, export=export)

        # CTC projection: either tied to the output embedding or a fresh linear
        # layer over the target dictionary.
        if args.share_ctc_embed and embed_tokens is not None:
            self.proj = nn.Linear(
                embed_tokens.weight.shape[1],
                embed_tokens.weight.shape[0],
                bias=False,
            )
            self.proj.weight = embed_tokens.weight
        elif tgt_dict is not None:
            self.proj = Linear(args.encoder_embed_dim, len(tgt_dict))
        else:
            self.proj = None

        if args.relative_position_embedding:
            self.pos_emb = RelativePositionalEncoding(
                args.encoder_embed_dim // args.encoder_attention_heads,
                args.encoder_max_relative_position,
            )

    def build_encoder_layer(self, args):
        if args.use_sent_enc_layer:
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=args.encoder_embed_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
                has_relative_attention_bias=args.relative_position_embedding,
            )
        else:
            layer = TransformerEncoderLayer(args)
        return layer

    def forward(
        self,
        encoder_in,
        encoder_padding_mask,
        return_all_hiddens: bool = False,
        tgt_layer=None,
    ):
        """
        Args:
            encoder_in (torch.Tensor): input features of shape
                `(batch, src_len, embed_dim)`
            encoder_padding_mask (torch.BoolTensor): positions of padding
                elements of shape `(batch, src_len)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            tgt_layer (int, optional): if set, stop after this layer and return
                its output in place of the final layer's output

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_out_for_ctc** (Tensor): the CTC projection of the
                  encoder output, of shape `(src_len, batch, vocab_size)`,
                  or ``None`` if no projection is configured
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        if self.no_freeze_encoder_layer is None:
            ft = self.freeze_encoder_updates <= self.num_updates
        else:
            ft = True
        with torch.no_grad() if not ft else contextlib.ExitStack():
            encoder_out = self.forward_scriptable(
                encoder_in, encoder_padding_mask, return_all_hiddens, tgt_layer=tgt_layer,
            )

        # CTC and bert
        if self.proj:
            x_for_ctc = self.proj(self.dropout_module(encoder_out["encoder_out"][0]))
        else:
            x_for_ctc = None

        encoder_out["encoder_out_for_ctc"] = [x_for_ctc]  # T x B x C

        return encoder_out
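
    # Note (illustrative, not from the original source): ``encoder_out_for_ctc``
    # holds unnormalized vocabulary logits (T x B x len(tgt_dict)); a CTC
    # criterion would typically apply a log_softmax over the last dimension
    # before computing the loss.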

    # TorchScript doesn't support super() calls, so a scriptable subclass can't
    # reach the base class implementation from TorchScript. The current
    # workaround is to put the body in a helper with a different name and call
    # that helper from the scriptable subclass.
    def forward_scriptable(
        self,
        encoder_in,
        encoder_padding_mask,
        return_all_hiddens: bool = False,
        tgt_layer=None,
    ):
        """
        Scriptable body of :meth:`forward`.

        Args:
            encoder_in (torch.Tensor): input features of shape
                `(batch, src_len, embed_dim)`
            encoder_padding_mask (torch.BoolTensor): positions of padding
                elements of shape `(batch, src_len)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            tgt_layer (int, optional): if set, stop after this layer and return
                its output in place of the final layer's output

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **decoder_input** (Tensor): output of layer *unb_enc_layer*,
                  if one is configured (otherwise ``None``)
        """
        if self.no_freeze_encoder_layer is not None:
            ft = self.freeze_encoder_updates <= self.num_updates
        else:
            ft = True

        with torch.no_grad() if not ft else contextlib.ExitStack():
            # compute padding mask
            if not self.use_sent_enc_layer:
                has_pads = encoder_in.device.type == "xla" or encoder_padding_mask.any()

            if not self.layer_norm_first:
                encoder_in = self.layer_norm(encoder_in)

            encoder_in = self.dropout_module(encoder_in)

            # B x T x C -> T x B x C
            x = encoder_in.transpose(0, 1)

            encoder_states = []

            if return_all_hiddens:
                encoder_states.append(x)

            # relative position embedding
            if self.args.relative_position_embedding:
                x_len = x.shape[0]
                pos_seq = torch.arange(0, x_len).long().to(x.device)
                pos_seq = pos_seq[:, None] - pos_seq[None, :]
                pos_k, pos_v = self.pos_emb(pos_seq)
            else:
                pos_k = None

        # encoder layers
        r = None
        d = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            with torch.no_grad() if (not ft) and i not in self.no_freeze_encoder_layer else contextlib.ExitStack():
                if not self.training or (dropout_probability > self.encoder_layerdrop) or i == self.unb_enc_layer:
                    if self.use_sent_enc_layer:
                        x, _ = layer(x, self_attn_padding_mask=encoder_padding_mask, self_attn_mask=None, need_weights=False, pos_bias=pos_k)
                    else:
                        x = layer(x, encoder_padding_mask=encoder_padding_mask if has_pads else None, attn_mask=None)
                if i == self.unb_enc_layer:
                    d = x

                if i == tgt_layer:
                    r = x
                    break

            if return_all_hiddens:
                assert encoder_states is not None
                encoder_states.append(x)

        with torch.no_grad() if not ft else contextlib.ExitStack():
            # Finally T x B x C
            if self.layer_norm_first:
                x = self.layer_norm(x.transpose(0, 1)).transpose(0, 1)

            if r is not None:
                x = r

        # The PyTorch Mobile lite interpreter does not support returning NamedTuple in
        # `forward` so we use a dictionary instead.
        # TorchScript does not support mixed values so the values are all lists.
        # The empty list is equivalent to None.
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": [],
            "decoder_input": [d],
        }
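
    # Note (illustrative, not from the original source): during training each
    # encoder layer is skipped with probability ``encoder_layerdrop``
    # (LayerDrop), except the layer indexed by ``unb_enc_layer``, which always
    # runs so its output can be exposed as ``decoder_input``.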

    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_out_for_ctc"]) == 0:
            new_x_for_ctc = []
        else:
            new_x_for_ctc = [encoder_out["encoder_out_for_ctc"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        if len(encoder_out["src_tokens"]) == 0:
            src_tokens = []
        else:
            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]
        if len(encoder_out["decoder_input"]) == 0 or encoder_out["decoder_input"][0] is None:
            new_decoder_input = []
        else:
            new_decoder_input = [
                encoder_out["decoder_input"][0].index_select(0, new_order)
            ]

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": src_tokens,  # B x T
            "encoder_out_for_ctc": new_x_for_ctc,  # T x B x C
            "decoder_input": new_decoder_input,
        }
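
    # Note (illustrative, not from the original source): ``reorder_encoder_out``
    # is called by fairseq's sequence generator during beam search to realign
    # cached encoder outputs with reordered beams; time-major tensors
    # (T x B x C) are gathered along dim 1, batch-major ones along dim 0.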

    # def max_positions(self):
    #     """Maximum input length supported by the encoder."""
    #     return self.max_source_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        # if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
        #     weights_key = "{}.embed_positions.weights".format(name)
        #     if weights_key in state_dict:
        #         print("deleting {0}".format(weights_key))
        #         del state_dict[weights_key]
        #     state_dict[
        #         "{}.embed_positions._float_tensor".format(name)
        #     ] = torch.FloatTensor(1)
        for i in range(self.num_layers):
            # update layer norms
            if not isinstance(self.layers[i], TransformerSentenceEncoderLayer):
                self.layers[i].upgrade_state_dict_named(
                    state_dict, "{}.layers.{}".format(name, i)
                )

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates
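

# The function below is an illustrative usage sketch, not part of the original
# ArTST code and never called by it. It builds an argparse.Namespace carrying
# the hyper-parameters read by __init__/build_encoder_layer above; the values
# are placeholders, not ArTST's released configuration.
def _example_encoder_forward():
    """Hypothetical demo: construct a tiny encoder and run one forward pass."""
    from argparse import Namespace

    args = Namespace(
        encoder_layers=2,
        encoder_embed_dim=64,
        encoder_ffn_embed_dim=128,
        encoder_attention_heads=4,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        activation_fn="gelu",
        encoder_layerdrop=0.0,
        layer_norm_first=False,
        layer_norm_eps=1e-5,
        use_sent_enc_layer=True,
        freeze_encoder_updates=0,
        no_freeze_encoder_layer=None,
        share_ctc_embed=False,
        relative_position_embedding=True,
        encoder_max_relative_position=160,
    )
    encoder = TransformerEncoder(args)  # no tgt_dict, so the CTC projection is disabled
    feats = torch.randn(2, 10, args.encoder_embed_dim)    # B x T x C input features
    padding_mask = torch.zeros(2, 10, dtype=torch.bool)   # B x T, no padded positions
    out = encoder(feats, padding_mask)
    return out["encoder_out"][0].shape                    # (T, B, C) = (10, 2, 64)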