# File size: 1,108 Bytes
# Commit: 91a4112
from transformers import AutoConfig, AutoModel
import torch

from .modeling_chada_vit import ChAdaViTModel
from .config_chada_vit import ChAdaViTConfig

# Register the custom ChAdaViT classes with the transformers Auto* factories so
# that AutoModel.from_config(...) further down can build a ChAdaViTModel from a
# ChAdaViTConfig. transformers requires the "chadavit" string to match
# ChAdaViTConfig.model_type.
AutoConfig.register("chadavit", ChAdaViTConfig)
AutoModel.register(ChAdaViTConfig, ChAdaViTModel)

# Mark both classes for export as custom code. When the model is saved or
# pushed to the Hub, transformers copies their source files and records an
# `auto_map` entry in config.json, so users can load them with
# trust_remote_code=True.
ChAdaViTConfig.register_for_auto_class()
ChAdaViTModel.register_for_auto_class("AutoModel")

# Default ChAdaViT configuration used to build the model.
# NOTE(review): presumably these defaults must match the architecture the
# checkpoint loaded further down was trained with — confirm.
config = ChAdaViTConfig()


def adjust_keys(state_dict):
    """Return a new dict whose keys are *state_dict*'s keys rewritten.

    Each key is transformed as follows:

    * every occurrence of ``"encoder"`` is replaced by ``"backbone"``;
    * if the *original* key contained ``"backbone"``, every occurrence of
      ``"backbone."`` is then removed from the result.

    Values are carried over as-is (not copied) and insertion order is kept.
    If two keys map to the same new name, the later one wins.
    """

    def _rename(old_key):
        # str.replace is a no-op when "encoder" is absent, so no guard needed.
        renamed = old_key.replace("encoder", "backbone")
        # NOTE(review): this tests the ORIGINAL key, not `renamed`. A key that
        # only gains "backbone" from the substitution above therefore keeps its
        # "backbone." prefix (e.g. "encoder.x" -> "backbone.x"). Confirm
        # whether that is intended.
        if "backbone" in old_key:
            renamed = renamed.replace("backbone.", "")
        return renamed

    return {_rename(old_key): tensor for old_key, tensor in state_dict.items()}


# Initialize the model from the default ChAdaViT config through the Auto*
# registration done at the top of this script.
model = AutoModel.from_config(config)

# Load the training checkpoint and keep only its "state_dict" entry.
# NOTE(review): torch.load unpickles the file, so only point it at trusted
# checkpoints. Recent PyTorch releases default to weights_only=True, which can
# reject checkpoints that also pickle non-tensor objects (e.g. hyper-parameters).
ckpt_path = "CKPT_PTH.ckpt"  # looks like a placeholder — set the real path
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

# Remap checkpoint key names and load them into the model.
# strict=False tolerates keys that do not line up, so report them explicitly
# rather than pushing a partially-initialised model without any signal.
adjusted_state_dict = adjust_keys(state_dict)
load_result = model.load_state_dict(adjusted_state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"[warning] {len(load_result.missing_keys)} model parameter(s) not found "
        f"in the checkpoint (left at their initial values): {load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    print(
        f"[warning] {len(load_result.unexpected_keys)} checkpoint key(s) not used "
        f"by the model: {load_result.unexpected_keys}"
    )

model.push_to_hub("chadavit16-moyen")