import torch | |
from torch import nn | |
from transformers import BertPreTrainedModel | |
class ParagramSPModel(BertPreTrainedModel):
    """Bag-of-embeddings sentence encoder (Paragram-SP style) built on a BERT config.

    The model owns a single token-embedding table. A sentence embedding is the
    mean of the token embeddings selected by the attention mask; there are no
    attention or feed-forward layers.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # padding_idx: per nn.Embedding semantics, the pad token's row receives
        # no gradient updates.
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id,
        )
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids, attention_mask=None):
        """Embed tokens and mean-pool them over the positions kept by the mask.

        Args:
            input_ids: integer token ids; expected shape (batch, seq_len).
            attention_mask: tensor indexable as ``[:, :, None]`` against the
                embeddings; expected shape (batch, seq_len), 1 for tokens to
                include and 0 for tokens to ignore. If ``None``, every token
                is included.

        Returns:
            A 3-tuple ``(embeddings, mean_pooled_embeddings, embeddings)``:
            the per-token embeddings (input_ids shape + hidden_size), the
            masked mean over dim 1, and the per-token embeddings again.
            NOTE(review): the duplicated first/last element presumably mirrors
            a (last_hidden_state, pooler_output, hidden_states) layout that
            callers index into -- confirm before changing it.
        """
        embeddings = self.word_embeddings(input_ids)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Cast the mask once so the product and the count share the
        # embedding dtype (e.g. fp16) instead of relying on type promotion.
        mask = attention_mask[:, :, None].to(embeddings.dtype)
        summed = (embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1)

        # A row with no selected tokens would otherwise compute 0 / 0 = NaN.
        # Its numerator is already 0, so dividing by 1 yields a zero vector;
        # all other rows are unaffected.
        counts = counts.masked_fill(counts == 0, 1)
        mean_pooled_embeddings = summed / counts

        return (embeddings, mean_pooled_embeddings, embeddings)