File size: 446 Bytes
3b413ba
 
 
 
 
 
 
 
 
 
 
 
 
4ffede5
52f8006
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from torch import nn
# NOTE(review): transformers exports no ``MobileBert`` name (the encoder class
# is ``MobileBertModel``), so this import raises ImportError at module load.
# TODO: remove or replace once nothing depends on it.
from transformers import MobileBert
from transformers import PreTrainedModel

class SimModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper around a MobileBERT encoder.

    ``forward`` returns only the encoder's ``last_hidden_state`` instead of
    the full model-output object.

    Args:
        config: Configuration passed unchanged to both the ``PreTrainedModel``
            base class and the MobileBERT encoder (presumably a
            ``MobileBertConfig`` -- confirm with callers).
    """

    def __init__(self, config):
        # Local import keeps the encoder dependency self-contained in this
        # class; transformers has no ``MobileBert`` name, only
        # ``MobileBertModel``.
        from transformers import MobileBertModel

        super().__init__(config)
        self.config = config
        self.encoder = MobileBertModel(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, **input_args):
        """Run the encoder and return its last hidden state.

        Args:
            **input_args: Keyword arguments forwarded verbatim to the encoder
                (e.g. ``input_ids``, ``attention_mask``).

        Returns:
            The ``last_hidden_state`` attribute of the encoder's output.
        """
        # Inputs are forwarded as-is. They must not be splatted into print():
        # print() rejects keywords other than sep/end/file/flush, so doing
        # that raised TypeError on every call.
        return self.encoder(**input_args).last_hidden_state