nicoboou commited on
Commit
91a4112
1 Parent(s): 9390402

Create upload_model.py

Browse files
Files changed (1) hide show
  1. upload_model.py +39 -0
upload_model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModel
2
+ import torch
3
+
4
+ from .modeling_chada_vit import ChAdaViTModel
5
+ from .config_chada_vit import ChAdaViTConfig
6
+
7
+ AutoConfig.register("chadavit", ChAdaViTConfig)
8
+ AutoModel.register(ChAdaViTConfig, ChAdaViTModel)
9
+
10
+ ChAdaViTConfig.register_for_auto_class()
11
+ ChAdaViTModel.register_for_auto_class("AutoModel")
12
+
13
+ config = ChAdaViTConfig()
14
+
15
+
16
+ def adjust_keys(state_dict):
17
+ new_state_dict = {}
18
+ for key in state_dict:
19
+ new_key = key
20
+ if "encoder" in key:
21
+ new_key = new_key.replace("encoder", "backbone")
22
+ if "backbone" in key:
23
+ new_key = new_key.replace("backbone.", "")
24
+ new_state_dict[new_key] = state_dict[key]
25
+ return new_state_dict
26
+
27
+
28
+ # Initialize model
29
+ model = AutoModel.from_config(config)
30
+
31
+ # Load state dictionary
32
+ ckpt_path = "CKPT_PTH.ckpt"
33
+ state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
34
+
35
+ # Adjust state dictionary keys and load it into the model
36
+ adjusted_state_dict = adjust_keys(state_dict)
37
+ model.load_state_dict(adjusted_state_dict, strict=False)
38
+
39
+ model.push_to_hub("chadavit16-moyen")