# AbLang_heavy / model.py
import torch
from transformers import PreTrainedModel
from .extra_fns import ACT2FN
from .encoderblocks import EncoderBlocks
from .config import AbLangConfig


class AbEmbeddings(PreTrainedModel):
    """Amino-acid and position embeddings, summed and passed through LayerNorm and dropout."""

    def __init__(self, config):
        super().__init__(config)
        self.pad_token_id = config.ptid
        self.AAEmbeddings = torch.nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.pad_token_id)
        self.PositionEmbeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)  # here padding_idx is always 0
        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.Dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def forward(self, src):
        inputs_embeds = self.AAEmbeddings(src)
        position_ids = self.create_position_ids_from_input_ids(src, self.pad_token_id)
        position_embeddings = self.PositionEmbeddings(position_ids)
        embeddings = inputs_embeds + position_embeddings
        return self.Dropout(self.LayerNorm(embeddings))

    def create_position_ids_from_input_ids(self, input_ids, padding_idx):
        """
        Replace non-padding symbols with their position numbers (starting at 1).
        Padding tokens get position 0, which is ignored later on.
        """
        mask = input_ids.ne(padding_idx).int()
        return torch.cumsum(mask, dim=1).long() * mask
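
# Illustrative example (hypothetical values) of create_position_ids_from_input_ids:
#   input_ids = [[5, 7, 9, pad, pad]] with padding_idx = pad
#   mask      = [[1, 1, 1, 0, 0]]
#   cumsum    = [[1, 2, 3, 3, 3]]
#   positions = [[1, 2, 3, 0, 0]]   (padding stays at position 0)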


class AbLang(PreTrainedModel):
    """AbLang encoder: embeddings followed by a stack of transformer encoder blocks."""

    config_class = AbLangConfig

    def __init__(self, config):
        super().__init__(config)
        self.AbEmbeddings = AbEmbeddings(config)
        self.EncoderBlocks = EncoderBlocks(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        src = self.AbEmbeddings(input_ids)
        # EncoderBlocks expects an inverted mask (1 = padding, 0 = real token),
        # hence 1 - attention_mask.
        outputs = self.EncoderBlocks(src,
                                     attention_mask=1 - attention_mask,
                                     output_attentions=output_attentions,
                                     output_hidden_states=output_hidden_states)
        return apply_cls_embeddings(attention_mask, outputs)


def apply_cls_embeddings(attention_mask, outputs):
    """Overwrite position 0 ([CLS]) with the mean of the residue embeddings,
    excluding the [CLS] and [SEP] tokens and any padding."""
    mask = attention_mask.float()
    # torch.nonzero yields (row, column) pairs in row-major order, so later pairs
    # overwrite earlier ones: d maps each sequence to its last non-padding index,
    # i.e. the position of its [SEP] token.
    d = {k: v for k, v in torch.nonzero(mask).cpu().numpy()}  # dict of sep tokens
    # make sep token invisible
    for i in d:
        mask[i, d[i]] = 0
    mask[:, 0] = 0.0  # make cls token invisible
    # Mean-pool the remaining residue embeddings and store the result at position 0.
    mask = mask.unsqueeze(-1).expand(outputs.last_hidden_state.size())
    sum_embeddings = torch.sum(outputs.last_hidden_state * mask, 1)
    sum_mask = torch.clamp(mask.sum(1), min=1e-9)
    outputs.last_hidden_state[:, 0, :] = sum_embeddings / sum_mask
    return outputs
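

# Minimal usage sketch (not part of the original module). Assumptions: the Hub repo id
# below matches this file's location, the repo ships a tokenizer that expects
# space-separated residues, and trust_remote_code=True is acceptable.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    repo_id = "qilowoq/AbLang_heavy"  # assumed from the file path
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    model.eval()

    seq = "EVQLVESGGGLVQPGGSLRLSCAAS"  # toy heavy-chain fragment
    enc = tokenizer(" ".join(seq), return_tensors="pt")

    with torch.no_grad():
        out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])

    # Position 0 of last_hidden_state now holds the mean-pooled sequence embedding
    # (see apply_cls_embeddings above).
    print(out.last_hidden_state[:, 0, :].shape)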