sapthesh commited on
Commit
8efe266
Β·
verified Β·
1 Parent(s): 669c94a

Update custom_model.py

Browse files
Files changed (1) hide show
  1. custom_model.py +8 -6
custom_model.py CHANGED
@@ -1,8 +1,7 @@
1
  import torch
2
  import torch.nn as nn
3
- from transformers import PreTrainedModel, PreTrainedTokenizer, AutoTokenizer, AutoConfig
4
 
5
- # Define a custom model class
6
  class CustomModel(PreTrainedModel):
7
  config_class = AutoConfig # Use AutoConfig to dynamically load the configuration class
8
 
@@ -14,11 +13,14 @@ class CustomModel(PreTrainedModel):
14
  @classmethod
15
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
16
  try:
 
17
  config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
 
18
  model = cls(config)
19
- # Load the weights from the pretrained model
20
- model.load_state_dict(torch.load(pretrained_model_name_or_path))
 
21
  return model
22
- except FileNotFoundError:
23
- print(f"Failed to load model from {pretrained_model_name_or_path}. Please check the model name and path.")
24
  return None
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import PreTrainedModel, AutoConfig
4
 
 
5
  class CustomModel(PreTrainedModel):
6
  config_class = AutoConfig # Use AutoConfig to dynamically load the configuration class
7
 
 
13
  @classmethod
14
  def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
15
  try:
16
+ # Load the configuration
17
  config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
18
+ # Initialize the model with the configuration
19
  model = cls(config)
20
+ # Load the model weights using the transformers library
21
+ state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location="cpu")
22
+ model.load_state_dict(state_dict)
23
  return model
24
+ except Exception as e:
25
+ print(f"Failed to load model from {pretrained_model_name_or_path}. Error: {e}")
26
  return None