import torch | |
from transformers import PreTrainedModel | |
from .configuration_resnet import ResnetConfig | |
class ResnetModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` that wraps a single ``torch.nn.Linear(5, 10)``.

    ``forward`` expects a tensor whose last dimension is 5 and returns a
    tensor with the same leading dimensions and a last dimension of 10.
    """

    # Configuration class transformers associates with this model
    # (e.g. when instantiating it via ``from_pretrained``).
    config_class = ResnetConfig

    def __init__(self, config):
        """Build the model.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel``.
                No fields of it are read here; the layer sizes are fixed.
        """
        super().__init__(config)
        self.model = torch.nn.Linear(5, 10)

    def forward(self, tensor):
        """Apply the wrapped linear layer to ``tensor``.

        Args:
            tensor: Input tensor whose last dimension must be 5.

        Returns:
            Output of ``torch.nn.Linear(5, 10)``: same leading dimensions as
            ``tensor``, last dimension 10.
        """
        # nn.Linear has no ``forward_features`` method (that is a timm model
        # API), so the previous call always raised AttributeError. Calling the
        # module directly runs its forward pass together with any hooks.
        return self.model(tensor)