import torch

from transformers import PreTrainedModel

from .extra_fns import ACT2FN
from .encoderblocks import EncoderBlocks
from .config import AbLangConfig


class AbEmbeddings(PreTrainedModel):
    """Residue and position embeddings, followed by LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = config.ptid
        # Residue (token) embeddings; the padding index maps to a zero vector that is not updated during training.
        self.AAEmbeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.pad_token_id)
        # Learned absolute position embeddings; position 0 is reserved for padding tokens.
        self.PositionEmbeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.Dropout = torch.nn.Dropout(config.hidden_dropout_prob)
    def forward(self, src):
        inputs_embeds = self.AAEmbeddings(src)
        # Positions are counted from 1; padding tokens are assigned position 0.
        position_ids = self.create_position_ids_from_input_ids(src, self.pad_token_id)
        position_embeddings = self.PositionEmbeddings(position_ids)

        embeddings = inputs_embeds + position_embeddings
        return self.Dropout(self.LayerNorm(embeddings))
    def create_position_ids_from_input_ids(self, input_ids, padding_idx):
        """
        Replace non-padding symbols with their position numbers, counting from 1.
        Padding tokens get position 0, which is the padding_idx of PositionEmbeddings
        and is therefore ignored.

        Example (token ids are illustrative): input_ids = [[3, 8, 5, pad, pad]]
        gives position_ids = [[1, 2, 3, 0, 0]].
        """
        mask = input_ids.ne(padding_idx).int()
        return torch.cumsum(mask, dim=1).long() * mask

class AbLang(PreTrainedModel):
    """AbLang model: AbEmbeddings followed by EncoderBlocks."""

    config_class = AbLangConfig

    def __init__(self, config):
        super().__init__(config)
        self.AbEmbeddings = AbEmbeddings(config)
        self.EncoderBlocks = EncoderBlocks(config)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        if attention_mask is None:
            # Default to attending to every position when no mask is supplied.
            attention_mask = torch.ones_like(input_ids)

        src = self.AbEmbeddings(input_ids)
        # attention_mask follows the Hugging Face convention (1 = real token, 0 = padding);
        # EncoderBlocks expects the inverted mask, so flip it here.
        outputs = self.EncoderBlocks(
            src,
            attention_mask=1 - attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        return apply_cls_embeddings(attention_mask, outputs)

def apply_cls_embeddings(attention_mask, outputs):
    """
    Overwrite the [CLS] position of last_hidden_state with the mean of the residue
    embeddings, excluding [CLS], the final non-padding token and padding.
    """
    mask = attention_mask.float()
    # torch.nonzero returns (row, column) pairs for every non-padding position;
    # keying the dict by row keeps only the last column per row, i.e. the index
    # of the final non-padding token (the separator) in each sequence.
    sep_positions = {row: col for row, col in torch.nonzero(mask).cpu().numpy()}
    for row, col in sep_positions.items():
        mask[row, col] = 0.0  # exclude the separator from the mean
    mask[:, 0] = 0.0  # exclude [CLS] from the mean

    mask = mask.unsqueeze(-1).expand(outputs.last_hidden_state.size())
    sum_embeddings = torch.sum(outputs.last_hidden_state * mask, 1)
    sum_mask = torch.clamp(mask.sum(1), min=1e-9)
    outputs.last_hidden_state[:, 0, :] = sum_embeddings / sum_mask
    return outputs
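

# ---------------------------------------------------------------------------
# Usage sketch (comments only, so nothing runs on import). This is a minimal
# illustration of how the classes above fit together; how the config and the
# token ids are obtained is an assumption, not part of this module.
#
#   config = AbLangConfig(...)                 # or AbLangConfig.from_pretrained(...)
#   model = AbLang(config)
#
#   input_ids = ...                            # LongTensor of tokenized antibody sequences
#   attention_mask = (input_ids != config.ptid).long()   # 1 = real token, 0 = padding
#
#   out = model(input_ids=input_ids, attention_mask=attention_mask)
#   residue_embeddings = out.last_hidden_state             # per-residue embeddings
#   sequence_embedding = out.last_hidden_state[:, 0, :]    # pooled embedding written into the [CLS] slot
# ---------------------------------------------------------------------------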