|
from transformers import AutoConfig, AutoModel |
|
import torch |
|
|
|
from .modeling_chada_vit import ChAdaViTModel |
|
from .config_chada_vit import ChAdaViTConfig |
|
|
|
# Make the Auto* factories aware of the custom architecture: the "chadavit"
# model type resolves to ChAdaViTConfig, and that config resolves to
# ChAdaViTModel.
AutoConfig.register("chadavit", ChAdaViTConfig)
AutoModel.register(ChAdaViTConfig, ChAdaViTModel)

# Tag both classes for auto-class export so their defining code is saved
# alongside the weights when the model is uploaded.
ChAdaViTConfig.register_for_auto_class(auto_class="AutoConfig")
ChAdaViTModel.register_for_auto_class(auto_class="AutoModel")

# Configuration built from ChAdaViTConfig's default arguments.
config = ChAdaViTConfig()
|
|
|
|
|
def adjust_keys(state_dict):
    """Return a new dict whose keys are rewritten from ``state_dict``'s keys.

    Every occurrence of ``"encoder"`` in a key is replaced by ``"backbone"``.
    Additionally, when the *original* key already contains ``"backbone"``,
    every ``"backbone."`` substring is removed from the renamed key. Values
    are carried over untouched and the input mapping is not modified.

    Args:
        state_dict: Mapping from parameter names to their values.

    Returns:
        A plain ``dict`` with the rewritten names, in the input's iteration
        order. If two names collapse to the same result, the later one wins.
    """

    def _rename(name):
        # str.replace is a no-op when "encoder" does not occur in the name.
        renamed = name.replace("encoder", "backbone")
        # The prefix strip is gated on the original name, so a key that only
        # contained "encoder" keeps its newly introduced "backbone" segment.
        if "backbone" in name:
            renamed = renamed.replace("backbone.", "")
        return renamed

    return {_rename(name): value for name, value in state_dict.items()}
|
|
|
|
|
|
|
# Instantiate the model from the configuration (no pretrained weights yet).
model = AutoModel.from_config(config)

# NOTE(review): "CKPT_PTH.ckpt" looks like a placeholder; point it at the
# real training checkpoint before running.
ckpt_path = "CKPT_PTH.ckpt"

# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

# Rewrite checkpoint parameter names before loading them into the model.
adjusted_state_dict = adjust_keys(state_dict)

# strict=False tolerates name mismatches, which would otherwise leave weights
# silently uninitialised. Report them so a bad key mapping is noticed before
# the model is published.
load_result = model.load_state_dict(adjusted_state_dict, strict=False)
if load_result.missing_keys:
    print(f"Missing keys (not found in checkpoint): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Unexpected keys (ignored from checkpoint): {load_result.unexpected_keys}")

# Upload the model to the Hugging Face Hub.
model.push_to_hub("chadavit16-moyen")
|
|