File size: 1,366 Bytes
3b413ba 0a985cf 3b413ba 9889278 3b413ba 0b9a150 3b413ba 5268492 b81b79a 77a3be3 bcab452 b81b79a bcab452 b81b79a 5645e15 b81b79a 6dd5977 a7d623d 5268492 b81b79a 0728fa3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import torch
from torch import nn
from transformers import BertPreTrainedModel
class ParagramSPModel(BertPreTrainedModel):
    """Sentence encoder that mean-pools a learned token-embedding table.

    There are no contextual layers: each token id is looked up in
    ``word_embeddings`` and the sentence vector is the average of the
    embeddings of the kept tokens (ids > 0).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.word_embeddings = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        # Initialize weights and apply final processing
        self.post_init()

    def _pad_id(self) -> int:
        """Return the id used to right-pad compacted rows (0 if none is configured)."""
        pad_id = self.config.pad_token_id
        return 0 if pad_id is None else pad_id

    def filter_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Drop ids <= 0 from each row and left-compact the survivors.

        The kept ids keep their relative order. They are moved to the front
        of the row, and the rest of the row is filled with the configured
        pad id. A row with no kept ids gets id 0 at position 0 (preserved
        from the original implementation).

        Args:
            input_ids: integer tensor indexed as ``[row, position]``
                (presumably ``(batch, seq_len)`` — confirm against callers).

        Returns:
            A tensor of the same shape, dtype and device as ``input_ids``.
        """
        keep = input_ids > 0
        seq_len = input_ids.shape[1]
        # A stable sort on the "drop" flag (0 = keep, 1 = drop) moves kept ids
        # to the front without reordering them among themselves.
        order = torch.sort((~keep).long(), dim=1, stable=True).indices
        compacted = torch.gather(input_ids, 1, order)

        lengths = keep.sum(dim=1, keepdim=True)
        positions = torch.arange(seq_len, device=input_ids.device)
        output = torch.where(
            positions[None, :] < lengths,
            compacted,
            torch.full_like(compacted, self._pad_id()),
        )
        # Match the original behaviour of emitting a lone id 0 for rows whose
        # ids were all filtered out.
        if seq_len > 0:
            output[lengths.squeeze(1) == 0, 0] = 0
        return output

    def forward(self, input_ids, attention_mask=None):
        """Embed ``input_ids`` and mean-pool over the kept (id > 0) tokens.

        Args:
            input_ids: integer token ids indexed as ``[row, position]``.
            attention_mask: accepted for interface compatibility but ignored.
                The mask is rebuilt from ``input_ids`` (ids > 0 are kept).

        Returns:
            ``(embeddings, mean_pooled_embeddings, embeddings)``, where
            ``embeddings`` are the per-position embeddings of the compacted
            ids and ``mean_pooled_embeddings`` is their average over kept
            tokens. A row with no kept tokens pools to a zero vector instead
            of NaN.
        """
        # Count the kept tokens before compaction. After compaction they occupy
        # positions [0, count), so the mask is independent of pad_token_id.
        kept_counts = (input_ids > 0).sum(dim=1)
        input_ids = self.filter_input_ids(input_ids)
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        attention_mask = positions[None, :] < kept_counts[:, None]

        embeddings = self.word_embeddings(input_ids)
        weights = attention_mask.unsqueeze(-1).to(embeddings.dtype)
        summed = (embeddings * weights).sum(dim=1)
        # Clamp to avoid 0/0 = NaN for rows with no kept tokens.
        counts = weights.sum(dim=1).clamp(min=1)
        mean_pooled_embeddings = summed / counts
        return (embeddings, mean_pooled_embeddings, embeddings)