Spaces:
Runtime error
Runtime error
Update custom_model.py
Browse files- custom_model.py +8 -6
custom_model.py
CHANGED
@@ -1,8 +1,7 @@
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
-
from transformers import PreTrainedModel,
|
4 |
|
5 |
-
# Define a custom model class
|
6 |
class CustomModel(PreTrainedModel):
|
7 |
config_class = AutoConfig # Use AutoConfig to dynamically load the configuration class
|
8 |
|
@@ -14,11 +13,14 @@ class CustomModel(PreTrainedModel):
|
|
14 |
@classmethod
|
15 |
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
16 |
try:
|
|
|
17 |
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
|
|
18 |
model = cls(config)
|
19 |
-
# Load the weights
|
20 |
-
|
|
|
21 |
return model
|
22 |
-
except
|
23 |
-
print(f"Failed to load model from {pretrained_model_name_or_path}.
|
24 |
return None
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
+
from transformers import PreTrainedModel, AutoConfig
|
4 |
|
|
|
5 |
class CustomModel(PreTrainedModel):
|
6 |
config_class = AutoConfig # Use AutoConfig to dynamically load the configuration class
|
7 |
|
|
|
13 |
@classmethod
|
14 |
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
15 |
try:
|
16 |
+
# Load the configuration
|
17 |
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
18 |
+
# Initialize the model with the configuration
|
19 |
model = cls(config)
|
20 |
+
# Load the model weights using the transformers library
|
21 |
+
state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location="cpu")
|
22 |
+
model.load_state_dict(state_dict)
|
23 |
return model
|
24 |
+
except Exception as e:
|
25 |
+
print(f"Failed to load model from {pretrained_model_name_or_path}. Error: {e}")
|
26 |
return None
|