File size: 512 Bytes
3b413ba
 
51df33d
3b413ba
7319aeb
3b413ba
 
 
65e061d
3b413ba
 
 
eddb7d2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from torch import nn
from transformers import PreTrainedModel, MobileBertModel

class SimModel(MobileBertModel):
    """MobileBERT encoder whose ``forward`` returns only the last hidden state.

    The class *is* a ``MobileBertModel`` (it subclasses it). The parent's own
    submodules are therefore the ones that run, and they are the ones that
    ``from_pretrained`` populates with checkpoint weights.
    """

    def __init__(self, config):
        # The parent constructor builds all submodules, stores ``self.config``
        # and runs weight initialisation, so nothing else is needed here.
        # Do NOT create a second, nested ``MobileBertModel`` (e.g. as
        # ``self.encoder``). It would shadow the parent's ``encoder`` submodule
        # and double the parameter count. It would also leave the modules that
        # actually run without the pretrained weights.
        super().__init__(config)

    def forward(self, input_ids, attention_mask=None):
        """Encode a batch of token ids and return the final hidden states.

        Args:
            input_ids: token-id tensor, forwarded unchanged to
                ``MobileBertModel.forward``.
            attention_mask: optional mask, forwarded unchanged; when ``None``
                the parent's default handling applies.

        Returns:
            The ``last_hidden_state`` field of the parent's forward output.
        """
        # return_dict=True guarantees a ModelOutput with attribute access even
        # if the config disables return_dict (which would yield a plain tuple).
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        return outputs.last_hidden_state