File size: 1,051 Bytes
2ce40b4
 
 
 
 
 
09d9f2a
2ce40b4
 
 
 
 
 
 
 
1fea6dc
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoConfig

# Define a custom model class
class CustomModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` subclass holding a single linear classifier head.

    The architecture is a placeholder: only ``self.classifier`` is defined.
    NOTE(review): no ``forward`` method is defined, so calling the model falls
    through to ``nn.Module``'s default, which raises ``NotImplementedError``.
    """

    # AutoConfig.from_pretrained picks the concrete config class from the
    # checkpoint, so this model accepts whatever config type is stored there.
    config_class = AutoConfig

    # Weights file looked up inside a checkpoint directory by from_pretrained.
    # This is the legacy (non-safetensors) PyTorch file name; safetensors
    # checkpoints are not handled here.
    _weights_file_name = "pytorch_model.bin"

    def __init__(self, config):
        """Build the model from *config*.

        Args:
            config: Configuration instance. It must expose ``hidden_size``
                (classifier input width) and ``num_labels`` (classifier
                output width).
        """
        super().__init__(config)
        # Implement your model architecture here
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    @classmethod
    def _resolve_weights_path(cls, pretrained_model_name_or_path):
        """Return ``<dir>/<_weights_file_name>`` for a directory, else the path unchanged."""
        if os.path.isdir(pretrained_model_name_or_path):
            return os.path.join(pretrained_model_name_or_path, cls._weights_file_name)
        return pretrained_model_name_or_path

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the config and weights from *pretrained_model_name_or_path*.

        The config is loaded with ``cls.config_class.from_pretrained``; all
        extra positional and keyword arguments are forwarded to that call.
        Weights are read with ``torch.load``:

        - If the path is a directory, the file named by ``_weights_file_name``
          inside it is used.
        - Otherwise the path itself is treated as the weights file.

        Weights are therefore only loaded from local files.

        Args:
            pretrained_model_name_or_path: Local checkpoint directory, a
                weights file path, or a name understood by the config loader.

        Returns:
            The model with the loaded state dict, or ``None`` if the config or
            weights could not be read (any ``OSError``). A message is printed
            in that case.
        """
        try:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
            weights_path = cls._resolve_weights_path(pretrained_model_name_or_path)
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
            # hosts; the freshly built model's parameters are on CPU anyway.
            # SECURITY NOTE: torch.load unpickles the file -- only load trusted
            # checkpoints (consider weights_only=True where torch supports it).
            state_dict = torch.load(weights_path, map_location="cpu")
        except OSError:
            # Covers FileNotFoundError, IsADirectoryError and the OSError that
            # transformers raises when a config cannot be found.
            print(f"Failed to load model from {pretrained_model_name_or_path}. Please check the model name and path.")
            return None
        model = cls(config)
        # Load the weights from the pretrained model
        model.load_state_dict(state_dict)
        return model