sapthesh committed
Commit 6f9d080 · verified · 1 Parent(s): fb6e979

Delete custom_model.py

Files changed (1):
  custom_model.py  +0 -35
custom_model.py DELETED
@@ -1,35 +0,0 @@
-import torch
-import torch.nn as nn
-from transformers import PreTrainedModel, AutoConfig, AutoModel
-
-class CustomModel(PreTrainedModel):
-    config_class = AutoConfig  # Use AutoConfig to dynamically load the configuration class
-
-    def __init__(self, config):
-        super().__init__(config)
-        # Implement your model architecture here
-        self.encoder = AutoModel.from_config(config)  # Load the base model
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
-    def forward(self, input_ids, attention_mask=None):
-        # Pass inputs through the encoder
-        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
-        # Get the pooled output (e.g., CLS token for classification tasks)
-        pooled_output = outputs.last_hidden_state[:, 0, :]
-        # Pass through the classifier
-        logits = self.classifier(pooled_output)
-        return logits
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        try:
-            # Load the configuration
-            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
-            # Initialize the model with the configuration
-            model = cls(config)
-            # Optionally, you can load the state_dict here if needed
-            # model.load_state_dict(torch.load(os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")))
-            return model
-        except Exception as e:
-            print(f"Failed to load model from {pretrained_model_name_or_path}. Error: {e}")
-            return None
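
For context on what was removed: the deleted class pointed config_class at AutoConfig rather than a concrete config class, and its from_pretrained override built the model from its config but never loaded the saved weights. Below is a minimal sketch of the more conventional pattern, assuming a BERT-style backbone; the checkpoint name, class name, and label count are placeholders for illustration and are not referenced anywhere in this repo.

import torch.nn as nn
from transformers import BertConfig, BertModel, PreTrainedModel

class BertClassifier(PreTrainedModel):
    # Pin config_class to a concrete config class so from_pretrained knows what to load.
    config_class = BertConfig
    base_model_prefix = "bert"

    def __init__(self, config):
        super().__init__(config)
        # Name the attribute "bert" so checkpoint keys (bert.*) line up with this module.
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = outputs.last_hidden_state[:, 0, :]  # [CLS] token representation
        return self.classifier(pooled)

# The inherited PreTrainedModel.from_pretrained loads both the config and the
# backbone weights, so no custom override is needed. "bert-base-uncased" and
# num_labels=2 are placeholder values.
model = BertClassifier.from_pretrained("bert-base-uncased", num_labels=2)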