# Hugging Face Spaces status: Runtime error
import logging

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoConfig
logger = logging.getLogger(__name__)


# Define a custom model class
class CustomModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` subclass with a single linear classification head.

    Extend ``__init__`` with the real architecture.

    NOTE(review): no ``forward`` is defined, so calling the model will raise
    ``NotImplementedError`` (``nn.Module``'s default) until one is added.
    """

    # AutoConfig selects the concrete config class from the checkpoint's
    # config.json, so ``config`` carries that architecture's own fields.
    config_class = AutoConfig

    def __init__(self, config):
        """Build the model from a loaded configuration.

        Args:
            config: A transformers config object; it must expose
                ``hidden_size`` and ``num_labels``.
        """
        super().__init__(config)
        # Implement your model architecture here
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # transformers documents post_init() as the last step of a model's
        # __init__ (weight initialisation and related setup that the
        # from_pretrained machinery relies on).
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load configuration and weights; return ``None`` if loading fails.

        Delegates to ``PreTrainedModel.from_pretrained``. It resolves Hub
        model ids and local directories, locates the weights file, and
        matches checkpoint tensors to this model's parameters by name.
        A raw ``torch.load`` of the identifier cannot do any of that.

        Args:
            pretrained_model_name_or_path: Hub model id or local directory.
            *args: Extra positional arguments forwarded to ``__init__``.
            **kwargs: Forwarded to ``PreTrainedModel.from_pretrained``. Keys
                matching config attributes (e.g. ``num_labels``) override the
                loaded configuration.

        Returns:
            The loaded model, or ``None`` if the checkpoint could not be found
            or read.
        """
        try:
            return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        except OSError as err:
            # OSError covers FileNotFoundError, IsADirectoryError and the
            # EnvironmentError transformers raises for unknown model ids or
            # missing weight files. Best-effort contract: report and return None.
            logger.error(
                "Failed to load model from %s. Please check the model name and path. (%s)",
                pretrained_model_name_or_path,
                err,
            )
            return None