File size: 841 Bytes
2ce40b4
 
 
 
 
 
09d9f2a
2ce40b4
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import os

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoConfig

# Define a custom model class
class CustomModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` subclass with a single linear classifier head.

    The configuration is loaded through ``AutoConfig`` and must expose
    ``hidden_size`` and ``num_labels``; those two values size ``classifier``.
    """

    # NOTE(review): AutoConfig is a factory that dispatches on the stored
    # config type, not a ``PretrainedConfig`` subclass. Only its
    # ``from_pretrained`` is used in this class; other transformers machinery
    # that instantiates ``config_class`` directly may not work with it — confirm.
    config_class = AutoConfig

    # File name of the PyTorch state dict looked up inside a checkpoint
    # directory; subclasses may override it.
    weights_name = "pytorch_model.bin"

    def __init__(self, config):
        """Build the model from an already-loaded configuration.

        Args:
            config: transformers configuration object providing
                ``hidden_size`` (classifier input width) and ``num_labels``
                (classifier output width).
        """
        super().__init__(config)
        # Implement your model architecture here
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the configuration and a PyTorch state dict from a checkpoint.

        Args:
            pretrained_model_name_or_path: A local checkpoint directory that
                holds both the config and ``weights_name``; any other value is
                passed to ``torch.load`` unchanged (the original behaviour).
            *args, **kwargs: Forwarded unchanged to
                ``config_class.from_pretrained``.

        Returns:
            A ``CustomModel`` with the stored weights loaded, in eval mode.

        Raises:
            FileNotFoundError: If the resolved weights file does not exist.
            RuntimeError: If the state dict does not match the model's
                parameters (``load_state_dict`` is strict).
        """
        config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        model = cls(config)

        # A checkpoint location is normally a directory (that is what the
        # config loader expects); torch.load needs the weights *file* in it.
        weights_path = pretrained_model_name_or_path
        if os.path.isdir(weights_path):
            weights_path = os.path.join(weights_path, cls.weights_name)

        # map_location="cpu" lets checkpoints saved on a GPU load on machines
        # without that device; callers can move the model afterwards.
        # SECURITY: torch.load unpickles its input — only load trusted
        # checkpoints (on torch>=1.13 consider weights_only=True).
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)

        # transformers' own from_pretrained returns models in eval mode;
        # keep that contract for this override.
        model.eval()
        return model