Spaces:
Runtime error
Runtime error
Update custom_model.py
Browse files- custom_model.py +9 -5
custom_model.py
CHANGED
@@ -13,8 +13,12 @@ class CustomModel(PreTrainedModel):
|
|
13 |
|
14 |
@classmethod
|
15 |
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
@classmethod
|
15 |
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
16 |
+
try:
|
17 |
+
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
|
18 |
+
model = cls(config)
|
19 |
+
# Load the weights from the pretrained model
|
20 |
+
model.load_state_dict(torch.load(pretrained_model_name_or_path))
|
21 |
+
return model
|
22 |
+
except FileNotFoundError:
|
23 |
+
print(f"Failed to load model from {pretrained_model_name_or_path}. Please check the model name and path.")
|
24 |
+
return None
|