sapthesh commited on
Commit
2ce40b4
·
verified ·
1 Parent(s): 5518c03

Create custom_model.py

Browse files
Files changed (1) hide show
  1. custom_model.py +20 -0
custom_model.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoConfig
4
+
5
+ # Define a custom model class
6
+ class CustomModel(PreTrainedModel):
7
+ config_class = None # Set this to the custom configuration class if available
8
+
9
+ def __init__(self, config):
10
+ super().__init__(config)
11
+ # Implement your model architecture here
12
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
13
+
14
+ @classmethod
15
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
16
+ config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
17
+ model = cls(config)
18
+ # Load the weights from the pretrained model
19
+ model.load_state_dict(torch.load(pretrained_model_name_or_path))
20
+ return model