import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from transformers import SpeechEncoderDecoderModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.encoder_decoder.modeling_encoder_decoder import shift_tokens_right


class Wav2VecGPT2Model(SpeechEncoderDecoderModel):
    """
    Basically the same as `SpeechEncoderDecoderModel`, except that position embeddings
    (initialized with GPT2's position embeddings) are added to the encoder output.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
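        # Position embeddings for the encoder output sequence, initialized from GPT-2's
        # learned position embedding table (`wpe`, 1024 positions).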
        self.encoder_outputs_pos_emb = nn.Embedding(1024, self.decoder.config.hidden_size)
        with torch.no_grad():
            self.encoder_outputs_pos_emb.weight.copy_(self.decoder.transformer.wpe.weight)
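        # LayerNorm applied after the encoder-to-decoder projection and position embeddings;
        # its eps matches the GPT-2 decoder's layer norm epsilon.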
        self.enc_to_dec_proj_ln = nn.LayerNorm(
            self.decoder.config.hidden_size, eps=self.decoder.config.layer_norm_epsilon
        )

    def __getattribute__(self, name):
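        # Masquerade as `SpeechEncoderDecoderModel`: `__class__`/`type(self)` checks elsewhere
        # in `transformers` then treat this subclass like the stock model.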
        if name == '__class__':
            return SpeechEncoderDecoderModel
        return SpeechEncoderDecoderModel.__getattribute__(self, name)

    def forward(
        self,
        inputs=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        input_values=None,
        input_features=None,
        return_dict=None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
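
        # Split extra kwargs: names prefixed with "decoder_" go to the decoder (prefix stripped),
        # the rest go to the encoder.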
        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_outputs is None and inputs is None:
            if input_values is not None and input_features is not None:
                raise ValueError("You cannot specify both input_values and input_features at the same time")
            elif input_values is not None:
                inputs = input_values
            elif input_features is not None:
                inputs = input_features
            else:
                raise ValueError("You have to specify either input_values or input_features")
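
        # Run the speech encoder only when no precomputed encoder outputs were passed
        # (generation supplies `encoder_outputs` and skips this step).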
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )

        encoder_hidden_states = encoder_outputs[0]
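
        # Project the encoder output to the decoder's hidden size (when needed), add the
        # GPT-2-initialized position embeddings, and normalize before cross-attention.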
        if (
            self.encoder_output_dim != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
            encoder_hidden_states += self.encoder_outputs_pos_emb(
                torch.arange(0, encoder_hidden_states.shape[1], device=encoder_hidden_states.device)
            )
            encoder_hidden_states = self.enc_to_dec_proj_ln(encoder_hidden_states)
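
        # The speech encoder downsamples the waveform, so recompute the padding mask
        # at the feature-frame resolution.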
        if attention_mask is not None:
            encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
                encoder_hidden_states.shape[1], attention_mask
            )
        else:
            encoder_attention_mask = None
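
        # When only labels are given, build the decoder inputs by shifting the labels one
        # position to the right.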
        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )
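
        # Compute the loss here rather than inside the decoder (some decoders shift the
        # logits internally).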
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.view(-1))

        if not return_dict:
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            else:
                return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs[0],
            encoder_hidden_states=getattr(encoder_outputs, 'hidden_states', None),
            encoder_attentions=getattr(encoder_outputs, 'attentions', None),
        )
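

# A minimal usage sketch (the checkpoint names are illustrative; any wav2vec2-style encoder
# plus a GPT-2 decoder should work, since __init__ reads the decoder's `transformer.wpe`):
#
#     model = Wav2VecGPT2Model.from_encoder_decoder_pretrained(
#         "facebook/wav2vec2-large-960h", "gpt2"
#     )
#     model.config.decoder_start_token_id = model.decoder.config.bos_token_id
#     model.config.pad_token_id = model.decoder.config.eos_token_id  # GPT-2 has no pad token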