Spaces:
Runtime error
Runtime error
File size: 1,051 Bytes
2ce40b4 09d9f2a 2ce40b4 1fea6dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoConfig
# Define a custom model class
class CustomModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` with a single linear classification head.

    The head maps vectors of size ``config.hidden_size`` to
    ``config.num_labels`` outputs.

    NOTE(review): no ``forward`` method is defined, so calling an instance
    falls through to ``nn.Module.forward`` and raises ``NotImplementedError``.
    TODO: implement the forward pass.
    """

    # AutoConfig resolves the concrete config class from the checkpoint's
    # config at load time, so any architecture's config can be used here.
    config_class = AutoConfig

    def __init__(self, config, *inputs, **kwargs):
        """Build the model from *config*.

        Args:
            config: A transformers config object; must expose ``hidden_size``
                and ``num_labels``.
            *inputs, **kwargs: Extra model arguments that
                ``PreTrainedModel.from_pretrained`` forwards to the
                constructor; passed through to the base class.
        """
        super().__init__(config, *inputs, **kwargs)
        # Implement your model architecture here.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # PreTrainedModel subclasses must call post_init() at the end of
        # __init__: it runs weight initialisation and sets up attributes that
        # the base-class loading code relies on.
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a model from a local checkpoint directory or a Hub model id.

        Delegates to ``PreTrainedModel.from_pretrained``. The base loader
        locates the weights file inside the directory/repository and loads it
        onto the right device. Checkpoint keys with no matching parameter are
        ignored, and parameters missing from the checkpoint (for example
        ``classifier``) keep their freshly initialised values. Loading the
        directory path itself with ``torch.load`` does none of this.

        Args:
            pretrained_model_name_or_path: Local directory or Hub model id.
            *args, **kwargs: Forwarded to ``PreTrainedModel.from_pretrained``.

        Returns:
            The loaded model, or ``None`` if the config or weights could not
            be found (a message is printed in that case).
        """
        try:
            return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        except OSError:
            # FileNotFoundError (the only error previously caught) is a
            # subclass of OSError; transformers reports missing local folders,
            # unknown Hub ids and missing weight files as OSError as well.
            print(f"Failed to load model from {pretrained_model_name_or_path}. Please check the model name and path.")
            return None