|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
import copy |
|
import logging |
|
import math |
|
from argparse import Namespace |
|
from dataclasses import dataclass, field |
|
from typing import Any, Optional |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from omegaconf import II, MISSING, open_dict |
|
|
|
from fairseq import checkpoint_utils, tasks, utils |
|
from fairseq.dataclass import FairseqDataclass |
|
from fairseq.dataclass.utils import convert_namespace_to_omegaconf |
|
from fairseq.models import ( |
|
BaseFairseqModel, |
|
FairseqEncoder, |
|
FairseqEncoderDecoderModel, |
|
FairseqIncrementalDecoder, |
|
register_model, |
|
) |
|
from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES |
|
from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer |
|
from fairseq.tasks import FairseqTask |
|
|
|
# Module-level logger for this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class HubertAsrConfig(FairseqDataclass):
    """Base configuration for fine-tuning a pretrained HuBERT model for ASR.

    The dropout, masking, layerdrop and feature-grad-mult fields are passed as
    overrides to the pretrained model when ``HubertEncoder`` loads it.
    ``normalize`` and ``data`` are interpolated from the task config.
    """

    w2v_path: str = field(default=MISSING, metadata={"help": "path to hubert model"})
    no_pretrained_weights: bool = field(
        default=False,
        metadata={"help": "if true, does not load pretrained weights"},
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    final_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout after transformer and before final projection"},
    )
    dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability inside hubert model"},
    )
    attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights " "inside hubert model"
        },
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN " "inside hubert model"
        },
    )
    encoder_embed_dim: Optional[int] = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )

    # masking (time steps)
    apply_mask: bool = field(
        default=False, metadata={"help": "apply masking during fine-tuning"}
    )
    mask_length: int = field(default=10, metadata={"help": "length of the mask"})
    mask_prob: float = field(
        default=0.5,
        metadata={
            "help": "probability of replacing a token with mask "
            "(normalized by length)"
        },
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose masks"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to allow masks to overlap"}
    )

    # channel masking
    mask_channel_length: int = field(
        default=10,
        metadata={"help": "length of the mask for features (channels)"},
    )
    mask_channel_prob: float = field(
        default=0.0,
        metadata={"help": "probability of replacing a feature with 0"},
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument "
            "(used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to allow channel masks to overlap"},
    )

    # fine-tuning schedule / regularization
    freeze_finetune_updates: int = field(
        default=0,
        metadata={"help": "dont finetune hubert for this many updates"},
    )
    feature_grad_mult: float = field(
        default=0.0,
        metadata={"help": "reset feature grad mult in hubert to this"},
    )
    layerdrop: float = field(
        default=0.0,
        metadata={"help": "probability of dropping a layer in hubert"},
    )

    # interpolated from the task config
    normalize: bool = II("task.normalize")
    data: str = II("task.data")

    # Holds the pretrained model's config once HubertEncoder has loaded it
    # (set there from the checkpoint when None).
    w2v_args: Any = None
|
|
|
|
|
@dataclass
class HubertCtcConfig(HubertAsrConfig):
    """Configuration of the ``hubert_ctc`` model (identical to HubertAsrConfig)."""
|
|
|
|
|
@register_model("hubert_ctc", dataclass=HubertCtcConfig)
class HubertCtc(BaseFairseqModel):
    """CTC ASR model: a :class:`HubertEncoder` whose outputs are used as logits."""

    def __init__(self, cfg: HubertCtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: HubertCtcConfig, task: FairseqTask):
        """Build a new model instance."""
        w2v_encoder = HubertEncoder(cfg, task)
        return cls(cfg, w2v_encoder)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def get_logits(self, net_output):
        """Return the encoder logits with padded frames forced onto class 0.

        The padding mask is transposed before indexing the logits. The logits
        are therefore expected time-major (T x B x V) and the mask batch-major
        (B x T), which is what ``HubertEncoder`` produces with ``tbc=True``.
        For every padded frame, class 0 gets 0 and all other classes get -inf.
        ``net_output["encoder_out"]`` is modified in place and returned.
        """
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if padding is not None and padding.any():
            # One replacement row: 0 for class 0, -inf for every other class.
            masked_row = torch.full(
                (logits.size(-1),),
                float("-inf"),
                dtype=logits.dtype,
                device=logits.device,
            )
            masked_row[0] = 0
            # A single boolean-index assignment writes into ``logits`` itself.
            # Chained indexing such as ``logits[mask][..., 0] = 0`` only writes
            # into a temporary copy and leaves ``logits`` unchanged.
            logits[padding.T] = masked_row
        return logits

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x
|
|
|
|
|
@dataclass
class HubertSeq2SeqConfig(HubertAsrConfig):
    """Configuration of the ``hubert_seq2seq`` model: HubertAsrConfig plus
    Transformer-decoder settings and optional warm-start from a checkpoint."""

    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )
    decoder_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    autoregressive: bool = II("task.autoregressive")
    seq2seq_path: str = field(
        default="",
        metadata={
            "help": "path to a seq2seq checkpoint whose model weights are "
            "loaded (non-strictly) into the new model"
        },
    )
    reset_dict: bool = field(
        default=False,
        metadata={
            "help": "if true, do not load decoder.embed_out and "
            "decoder.embed_tokens.weight from the loaded checkpoint"
        },
    )
|
|
|
|
|
@register_model("hubert_seq2seq", dataclass=HubertSeq2SeqConfig)
class HubertSeq2SeqModel(FairseqEncoderDecoderModel):
    """Encoder-decoder ASR model: a HuBERT encoder plus a Transformer decoder."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def _drop_dict_dependent_weights(state_dict):
        """Remove the decoder weights whose shape depends on the target dictionary.

        ``decoder.embed_out`` may be absent (e.g. with shared input/output
        embeddings), so missing keys are tolerated instead of raising KeyError.
        """
        for key in ("decoder.embed_out", "decoder.embed_tokens.weight"):
            state_dict.pop(key, None)

    @classmethod
    def build_model(cls, cfg: HubertSeq2SeqConfig, task: FairseqTask):
        """Build a new model instance.

        If ``cfg.seq2seq_path`` is set, the model weights of that checkpoint
        are loaded non-strictly. The dictionary-dependent decoder weights are
        dropped first when ``cfg.reset_dict`` is set.
        """
        assert (
            cfg.autoregressive
        ), "Please set task.autoregressive=true for seq2seq asr models"

        tgt_dict = task.target_dictionary

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            return Embedding(num_embeddings, embed_dim, padding_idx)

        decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)

        encoder = cls.build_encoder(cfg, task)
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)

        model = cls(encoder, decoder)

        # Attribute access works for both dataclass and DictConfig configs.
        if cfg.seq2seq_path:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.seq2seq_path)
            state = state["model"]
            if cfg.reset_dict:
                cls._drop_dict_dependent_weights(state)
            model.load_state_dict(state, strict=False)
        return model

    @classmethod
    def build_encoder(cls, cfg: HubertAsrConfig, task):
        return HubertEncoder(cfg, task)

    @classmethod
    def build_decoder(cls, cfg: HubertSeq2SeqConfig, tgt_dict, embed_tokens):
        return TransformerDecoder(cfg, tgt_dict, embed_tokens)

    def forward(self, **kwargs):
        encoder_out = self.encoder(**kwargs)
        decoder_out = self.decoder(encoder_out=encoder_out, **kwargs)
        return decoder_out

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args: Optional[Namespace] = None,
    ):
        """Load ``state_dict``.

        When ``model_cfg.reset_dict`` is true, the dictionary-dependent decoder
        weights are removed and loading is forced to be non-strict.
        ``model_cfg`` may be None (as in the call from ``build_model``); the
        state dict is then loaded as given.
        """
        if model_cfg is not None and getattr(model_cfg, "reset_dict", False):
            logger.warning("Overriding loading strict state dict!")
            self._drop_dict_dependent_weights(state_dict)
            return super().load_state_dict(state_dict, False, model_cfg, args)
        return super().load_state_dict(state_dict, strict, model_cfg, args)
|
|
|
|
|
class HubertEncoder(FairseqEncoder):
    """Wraps a pretrained HuBERT model as a fairseq encoder for fine-tuning.

    The pretrained model is rebuilt either from the checkpoint at
    ``cfg.w2v_path`` or from ``cfg.w2v_args`` when that is already set.
    Dropout, masking and layerdrop options are overridden from ``cfg``. An
    optional linear projection maps the features to the target vocabulary
    (non-autoregressive) or to ``cfg.decoder_embed_dim``.
    """

    def __init__(self, cfg: HubertAsrConfig, task):
        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                # Checkpoint has no "cfg" entry; convert its argparse "args".
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for "
            "both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        pretrain_task = tasks.setup_task(w2v_args.task)
        if state is not None and "task_state" in state:
            # Restore the pretraining task's stored state (from the checkpoint).
            pretrain_task.load_state_dict(state["task_state"])
        else:
            pretrain_task.load_state_dict(task.state_dict())

        model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
        if state is not None and not cfg.no_pretrained_weights:
            # strict=False: NOTE(review) presumably because checkpoint and
            # rebuilt model do not match key-for-key — confirm.
            model.load_state_dict(state["model"], strict=False)

        model.remove_pretraining_modules()

        super().__init__(pretrain_task.source_dictionary)

        self.apply_mask = cfg.apply_mask

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        # ``autoregressive`` is only declared on HubertSeq2SeqConfig, so CTC
        # configs must not be required to carry it.
        autoregressive = getattr(cfg, "autoregressive", False)
        if task.target_dictionary is not None and not autoregressive:
            self.proj = Linear(d, len(task.target_dictionary))
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            self.proj = Linear(d, cfg.decoder_embed_dim)
        else:
            self.proj = None

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        """Encode ``source`` with the pretrained model.

        Args:
            source: model input, passed to ``w2v_model.extract_features``
            padding_mask: input padding mask, passed through likewise
            tbc: if True, swap the first two dimensions of the features
                (presumably B x T x C -> T x B x C)

        Returns:
            dict with "encoder_out" (features, projected when ``self.proj`` is
            set). It also holds the returned padding mask under both
            "encoder_padding_mask" and "padding_mask".
        """
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        # Run the pretrained model without autograd until
        # ``freeze_finetune_updates`` updates have been made.
        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                x = x.transpose(0, 1)

        x = self.final_dropout(x)

        if self.proj is not None:
            x = self.proj(x)

        return {
            "encoder_out": x,
            "encoder_padding_mask": padding_mask,
            "padding_mask": padding_mask,
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        # "encoder_out" is reordered along dim 1, the padding masks along dim 0.
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        if encoder_out["padding_mask"] is not None:
            encoder_out["padding_mask"] = encoder_out["padding_mask"].index_select(
                0, new_order
            )
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
|
|
|
|
|
class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *cfg.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        cfg (HubertSeq2SeqConfig): model configuration
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): token embedding; also used as the
            output projection when ``share_decoder_input_output_embed`` is set
        no_encoder_attn (bool, optional): passed through to every
            :class:`TransformerDecoderLayer` (default: False).
    """

    def __init__(
        self,
        cfg: HubertSeq2SeqConfig,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)

        self.dropout = cfg.decoder_dropout
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder_embed_dim
        self.output_embed_dim = cfg.decoder_embed_dim

        self.layerdrop = cfg.decoder_layerdrop

        self.padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_target_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.decoder_learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )

        # Copy the config and map the decoder-specific dropout values onto the
        # generic names, presumably the ones TransformerDecoderLayer reads.
        # The deep copy leaves the caller's cfg untouched.
        transformer_cfg = copy.deepcopy(cfg)
        with open_dict(transformer_cfg):
            transformer_cfg.dropout = transformer_cfg.decoder_dropout
            transformer_cfg.attention_dropout = (
                transformer_cfg.decoder_attention_dropout
            )
            transformer_cfg.activation_dropout = (
                transformer_cfg.decoder_activation_dropout
            )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerDecoderLayer(transformer_cfg, no_encoder_attn)
                for _ in range(transformer_cfg.decoder_layers)
            ]
        )

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim**-0.5)

        if transformer_cfg.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Args:
            prev_output_tokens (LongTensor or list of 1-D tensors): previous
                decoder outputs of shape `(batch, tgt_len)`, for teacher
                forcing. A list is right-padded with ``padding_idx`` into one
                tensor.
            encoder_out (dict, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        if isinstance(prev_output_tokens, list):
            # Pad with ``padding_idx`` (not 0) so that the padded slots are
            # recognized as padding by the self-attention padding mask and by
            # the positional embedding in ``extract_features``.
            max_len = max(len(x) for x in prev_output_tokens)
            tmp = torch.full(
                (len(prev_output_tokens), max_len),
                self.padding_idx,
                dtype=torch.long,
                device=prev_output_tokens[0].device,
            )
            for i, p in enumerate(prev_output_tokens):
                tmp[i, : len(p)] = p
            prev_output_tokens = tmp
        prev_output_tokens = prev_output_tokens.long()
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state
        )
        x = self.output_layer(x)
        return x, extra

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        # During incremental decoding only the newest step is processed.
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers
        self_attn_padding_mask = None
        if prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
        for layer in self.layers:
            # LayerDrop: while training, skip this layer with prob. layerdrop.
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, attn, _ = layer(
                    x,
                    encoder_out["encoder_out"] if encoder_out is not None else None,
                    encoder_out["padding_mask"] if encoder_out is not None else None,
                    incremental_state,
                    self_attn_mask=self.buffered_future_mask(x)
                    if incremental_state is None
                    else None,
                    self_attn_padding_mask=self_attn_padding_mask,
                )
                inner_states.append(x)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        if self.share_input_output_embed:
            return F.linear(features, self.embed_tokens.weight)
        else:
            return F.linear(features, self.embed_out)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        """Return a cached (dim x dim) causal mask: -inf above the diagonal.

        The cache is rebuilt when missing, on a different device, or too small
        for ``tensor.size(0)``.
        """
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict
|
|
|
|
|
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` initialized from N(0, embedding_dim**-0.5).

    Args:
        num_embeddings: vocabulary size
        embedding_dim: size of each embedding vector
        padding_idx: index whose row is zeroed, or None for no padding row

    Returns:
        the initialized ``nn.Embedding``
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    if padding_idx is not None:
        # Guarded because ``m.weight[None]`` would select the whole matrix
        # and zero every embedding.
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
|
|
|
|
|
def Linear(in_features, out_features, bias=True):
    """Create an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: whether the layer gets an additive bias term

    Returns:
        the initialized ``nn.Linear``
    """
    layer = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(layer.weight)
    # nn.Linear registers ``bias`` as None when it is disabled.
    if layer.bias is not None:
        nn.init.constant_(layer.bias, 0.0)
    return layer
|
|