"""Load a ChAdaViT checkpoint into the Hugging Face wrapper and push it to the Hub.

The script registers the custom config/model classes with the ``Auto*``
factories, builds a model from the default ``ChAdaViTConfig``, loads the
remapped weights from ``CKPT_PTH.ckpt`` and uploads the result.
"""

import warnings

import torch
from transformers import AutoConfig, AutoModel

from .config_chada_vit import ChAdaViTConfig
from .modeling_chada_vit import ChAdaViTModel

# Make the custom classes resolvable through the Auto* factories in this process.
AutoConfig.register("chadavit", ChAdaViTConfig)
AutoModel.register(ChAdaViTConfig, ChAdaViTModel)

# Register the classes for auto-class loading of the pushed repository.
ChAdaViTConfig.register_for_auto_class()
ChAdaViTModel.register_for_auto_class("AutoModel")

config = ChAdaViTConfig()


def adjust_keys(state_dict):
    """Return a copy of ``state_dict`` with its keys renamed.

    Two substitutions are applied to every key:

    1. If the key contains ``"encoder"``, every occurrence is replaced with
       ``"backbone"``.
    2. If the *original* key contains ``"backbone"``, every occurrence of
       ``"backbone."`` is removed from the (possibly already renamed) key.

    Args:
        state_dict: Mapping from parameter names to values.

    Returns:
        A new ``dict`` holding the same values under the adjusted names. If
        two keys map to the same adjusted name, the later one wins.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = key
        if "encoder" in key:
            new_key = new_key.replace("encoder", "backbone")
        # NOTE(review): this tests the original ``key``, not ``new_key``, so a
        # key such as "encoder.x" becomes "backbone.x" and keeps its prefix.
        # Kept as-is because the intended mapping cannot be confirmed from
        # this file -- verify against the checkpoint's real key names.
        if "backbone" in key:
            new_key = new_key.replace("backbone.", "")
        new_state_dict[new_key] = value
    return new_state_dict


# Initialize model
model = AutoModel.from_config(config)

# Load state dictionary
ckpt_path = "CKPT_PTH.ckpt"
# SECURITY: ``torch.load`` unpickles the file, which can execute arbitrary
# code. Only point ``ckpt_path`` at checkpoints from a trusted source.
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

# Adjust state dictionary keys and load it into the model
adjusted_state_dict = adjust_keys(state_dict)
load_result = model.load_state_dict(adjusted_state_dict, strict=False)

# ``strict=False`` tolerates mismatched keys; surface them so an incomplete
# weight load is visible before the model is published.
if load_result.missing_keys:
    warnings.warn(
        f"{len(load_result.missing_keys)} model parameter(s) were not found in "
        f"the checkpoint and keep their initial values: "
        f"{load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    warnings.warn(
        f"{len(load_result.unexpected_keys)} checkpoint key(s) were not used "
        f"by the model: {load_result.unexpected_keys}"
    )

model.push_to_hub("chadavit16-moyen")