File size: 1,108 Bytes
91a4112 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from transformers import AutoConfig, AutoModel
import torch
from .modeling_chada_vit import ChAdaViTModel
from .config_chada_vit import ChAdaViTConfig
# Register the custom config/model classes with the transformers Auto* factories
# so AutoConfig / AutoModel can resolve the "chadavit" model type in this process.
# NOTE(review): transformers expects ChAdaViTConfig.model_type to equal
# "chadavit" for this registration to succeed — confirm in config_chada_vit.
AutoConfig.register("chadavit", ChAdaViTConfig)
AutoModel.register(ChAdaViTConfig, ChAdaViTModel)
# Mark both classes for auto-class export so that saving / pushing the model
# records them in the uploaded config (loadable with trust_remote_code).
# NOTE(review): export behavior as described in the transformers custom-model
# docs — confirm against the installed transformers version.
ChAdaViTConfig.register_for_auto_class()
ChAdaViTModel.register_for_auto_class("AutoModel")
# Default configuration; the model is built from it further down the script.
config = ChAdaViTConfig()
def adjust_keys(state_dict):
    """Rename checkpoint keys so they line up with the model's parameter names.

    Every ``"encoder"`` substring in a key is first rewritten to
    ``"backbone"``, and then every ``"backbone."`` occurrence is removed, so
    both ``"encoder.<name>"`` and ``"backbone.<name>"`` become ``"<name>"``.
    Keys containing neither substring pass through unchanged.

    Insertion order is preserved. If two input keys map to the same output
    key, the later one wins.

    Args:
        state_dict: Mapping from parameter name to value (presumably tensors
            from a training checkpoint — values are not inspected).

    Returns:
        A new dict with the renamed keys and the original values.
    """
    renamed = {}
    for key, value in state_dict.items():
        new_key = key.replace("encoder", "backbone")
        # Strip the prefix from the *renamed* key so keys that started as
        # "encoder." also lose it; otherwise they would not match any model
        # parameter and load_state_dict(strict=False) would silently skip them.
        new_key = new_key.replace("backbone.", "")
        renamed[new_key] = value
    return renamed
# Initialize the model from the default config; its weights are overwritten
# below from the training checkpoint.
model = AutoModel.from_config(config)

# Load the checkpoint on CPU so no GPU is needed for the conversion.
# NOTE(review): torch.load unpickles arbitrary Python objects — only run this on
# checkpoints from a trusted source. weights_only=True is not passed because a
# full training checkpoint may contain non-tensor objects; confirm before enabling.
ckpt_path = "CKPT_PTH.ckpt"
checkpoint = torch.load(ckpt_path, map_location="cpu")
# This checkpoint format nests the weights under "state_dict"; fall back to
# treating the loaded object itself as a bare state dict.
state_dict = checkpoint.get("state_dict", checkpoint)

# Adjust state dictionary keys and load them into the model. strict=False lets
# checkpoint-only entries (presumably training-time modules) be ignored, but it
# also hides model parameters that never received weights, so report both sets
# before publishing.
adjusted_state_dict = adjust_keys(state_dict)
incompatible = model.load_state_dict(adjusted_state_dict, strict=False)
if incompatible.missing_keys:
    print(
        f"WARNING: {len(incompatible.missing_keys)} model parameters were not "
        f"found in the checkpoint and keep their initial values: "
        f"{incompatible.missing_keys}"
    )
if incompatible.unexpected_keys:
    print(
        f"Ignored {len(incompatible.unexpected_keys)} checkpoint entries not "
        f"used by the model (first 10 shown): "
        f"{incompatible.unexpected_keys[:10]}"
    )

model.push_to_hub("chadavit16-moyen")
|