import torch
import torch.nn as nn
from transformers import RoFormerModel, RoFormerPreTrainedModel


class RoFormerForSparseEmbedding(RoFormerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # RoFormer backbone plus a linear head that maps each token's hidden
        # state to a single scalar term weight.
        self.encoder = RoFormerModel(config)
        self.linear_layer = nn.Linear(config.hidden_size, 1)

        self.post_init()

    def forward(self, input_ids, attention_mask, return_sparse=False):
        B, L = input_ids.shape

        # Per-token scalar weights from the encoder's last hidden states.
        last_hidden_states = self.encoder(input_ids, attention_mask)['last_hidden_state']
        token_weights = self.linear_layer(last_hidden_states).squeeze(-1)

        # Additive mask that pushes padding plus the first and last non-padding
        # tokens (typically [CLS]/[SEP]) to a large negative value, so the ReLU
        # below zeroes them out.
        token_mask = (1 - attention_mask) * -1e4
        token_mask[:, 0] = -1e4
        last_ind = torch.sum(attention_mask, -1, keepdim=True) - 1
        token_mask = torch.scatter(token_mask, -1, last_ind, -1e4)
        token_weights = token_weights + token_mask

        # Scatter each token's weight into a (B, L, vocab_size) tensor, then
        # max-pool over the sequence dimension: the result holds one
        # non-negative weight per vocabulary id (0 for ids absent from the input).
        emb = torch.zeros(B, L, self.encoder.config.vocab_size,
                          dtype=token_weights.dtype, device=token_weights.device)
        emb = torch.scatter(emb, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights.unsqueeze(-1))
        emb = torch.max(torch.relu(emb), dim=-2).values

        if return_sparse:
            emb = emb.to_sparse()

        return emb
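

# Minimal usage sketch: builds a tiny, randomly initialized RoFormer config
# purely to exercise the forward pass. The config values and example ids below
# are placeholders; in practice you would load fine-tuned weights via
# RoFormerForSparseEmbedding.from_pretrained(...) and a matching tokenizer.
if __name__ == "__main__":
    from transformers import RoFormerConfig

    config = RoFormerConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2,
                            num_attention_heads=2, intermediate_size=128)
    model = RoFormerForSparseEmbedding(config)
    model.eval()

    # One padded sequence: id 0 is padding, the mask marks 5 real tokens.
    input_ids = torch.tensor([[1, 5, 7, 9, 2, 0, 0]])
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]])

    with torch.no_grad():
        emb = model(input_ids, attention_mask, return_sparse=True)
    print(emb.shape)  # torch.Size([1, 1000]), one weight per vocabulary id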