---
license: apache-2.0
pipeline_tag: image-classification
---

PyTorch weights for the Kornia ViT, converted from the original Google Research JAX `vision_transformer` repository.

Using it with Kornia:

```python
from kornia.contrib import VisionTransformer

vit_model = VisionTransformer.from_config('vit_l/16', pretrained=True)
...
```

The original weights come from AugReg, as recommended by the Google Research `vision_transformer` repository. This checkpoint is the AugReg ViT-L/16 pretrained on ImageNet-21k.
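For orientation, the ViT-L/16 name fixes the architecture (as defined in the original ViT paper): 24 encoder blocks, hidden size 1024, 16 attention heads, MLP size 4096 and 16×16 patches. The sketch below is a hedged sanity check, not part of Kornia. It lists the shape each entry of the converted state dict should have, using the keys produced by `convert_jax_checkpoint` in this card, and it assumes 224×224 RGB inputs (14 × 14 patches plus one class token) and no classification head, since the conversion does not map one.

```python
import math

# ViT-L/16 hyperparameters; IMAGE = 224 is an assumption about this checkpoint's input size
HIDDEN, DEPTH, MLP_DIM, PATCH, IN_CH, IMAGE = 1024, 24, 4096, 16, 3, 224
num_tokens = (IMAGE // PATCH) ** 2 + 1  # 14 * 14 patches + 1 class token = 197

# Expected shapes, keyed like the output of convert_jax_checkpoint (no classifier head)
shapes = {
    "patch_embedding.cls_token": (1, 1, HIDDEN),
    "patch_embedding.backbone.weight": (HIDDEN, IN_CH, PATCH, PATCH),
    "patch_embedding.backbone.bias": (HIDDEN,),
    "patch_embedding.positions": (num_tokens, HIDDEN),
    "norm.weight": (HIDDEN,),
    "norm.bias": (HIDDEN,),
}
for i in range(DEPTH):
    p = f"encoder.blocks.{i}"
    shapes.update({
        f"{p}.0.fn.0.weight": (HIDDEN,),                       # pre-attention LayerNorm
        f"{p}.0.fn.0.bias": (HIDDEN,),
        f"{p}.0.fn.1.qkv.weight": (3 * HIDDEN, HIDDEN),        # fused query/key/value
        f"{p}.0.fn.1.qkv.bias": (3 * HIDDEN,),
        f"{p}.0.fn.1.projection.weight": (HIDDEN, HIDDEN),     # attention output projection
        f"{p}.0.fn.1.projection.bias": (HIDDEN,),
        f"{p}.1.fn.0.weight": (HIDDEN,),                       # pre-MLP LayerNorm
        f"{p}.1.fn.0.bias": (HIDDEN,),
        f"{p}.1.fn.1.0.weight": (MLP_DIM, HIDDEN),             # MLP expansion
        f"{p}.1.fn.1.0.bias": (MLP_DIM,),
        f"{p}.1.fn.1.3.weight": (HIDDEN, MLP_DIM),             # MLP contraction
        f"{p}.1.fn.1.3.bias": (HIDDEN,),
    })

total = sum(math.prod(s) for s in shapes.values())
print(len(shapes), total)  # prints: 294 303301632
```

Comparing these against `{k: tuple(v.shape) for k, v in state_dict.items()}` catches a missing or mis-shaped tensor, but square matrices such as `projection.weight` pass a shape check even when a transpose is wrong.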

The weights were converted to PyTorch for the Kornia ViT implementation by @gau-nernst in kornia/kornia#2786.

Function used to convert the JAX checkpoint:

```python
import itertools

import numpy as np
import torch


def convert_jax_checkpoint(np_state_dict: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
    """Map a JAX ViT checkpoint (parameter paths joined with "/") onto Kornia's ViT state dict."""

    def get_weight(key: str) -> torch.Tensor:
        return torch.from_numpy(np_state_dict[key])

    state_dict = dict()
    state_dict["patch_embedding.cls_token"] = get_weight("cls")
    # Flax conv kernel is (H, W, in, out); PyTorch Conv2d expects (out, in, H, W)
    state_dict["patch_embedding.backbone.weight"] = get_weight("embedding/kernel").permute(3, 2, 0, 1)
    state_dict["patch_embedding.backbone.bias"] = get_weight("embedding/bias")
    # Drop the leading batch dimension of the position embedding
    state_dict["patch_embedding.positions"] = get_weight("Transformer/posembed_input/pos_embedding").squeeze(0)

    # Walk the encoder blocks until the checkpoint has no more of them
    for i in itertools.count():
        prefix1 = f"encoder.blocks.{i}"
        prefix2 = f"Transformer/encoderblock_{i}"

        if f"{prefix2}/LayerNorm_0/scale" not in np_state_dict:
            break

        # Pre-attention LayerNorm
        state_dict[f"{prefix1}.0.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_0/scale")
        state_dict[f"{prefix1}.0.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_0/bias")

        # Fuse the query/key/value kernels, each (hidden, heads, head_dim), into one
        # (3 * hidden, hidden) Linear weight; biases (heads, head_dim) are fused the same way
        mha_prefix = f"{prefix2}/MultiHeadDotProductAttention_1"
        qkv_weight = [get_weight(f"{mha_prefix}/{x}/kernel") for x in ["query", "key", "value"]]
        qkv_bias = [get_weight(f"{mha_prefix}/{x}/bias") for x in ["query", "key", "value"]]
        state_dict[f"{prefix1}.0.fn.1.qkv.weight"] = torch.cat(qkv_weight, 1).flatten(1).T
        state_dict[f"{prefix1}.0.fn.1.qkv.bias"] = torch.cat(qkv_bias, 0).flatten()
        # Output kernel is (heads, head_dim, hidden): merge the heads, then transpose to (out, in)
        state_dict[f"{prefix1}.0.fn.1.projection.weight"] = get_weight(f"{mha_prefix}/out/kernel").flatten(0, 1).T
        state_dict[f"{prefix1}.0.fn.1.projection.bias"] = get_weight(f"{mha_prefix}/out/bias")

        # Pre-MLP LayerNorm and the two MLP layers; Flax Dense kernels are (in, out),
        # nn.Linear weights are (out, in), hence the transposes
        state_dict[f"{prefix1}.1.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_2/scale")
        state_dict[f"{prefix1}.1.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_2/bias")
        state_dict[f"{prefix1}.1.fn.1.0.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/kernel").T
        state_dict[f"{prefix1}.1.fn.1.0.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/bias")
        state_dict[f"{prefix1}.1.fn.1.3.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/kernel").T
        state_dict[f"{prefix1}.1.fn.1.3.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/bias")

    # Final encoder LayerNorm
    state_dict["norm.weight"] = get_weight("Transformer/encoder_norm/scale")
    state_dict["norm.bias"] = get_weight("Transformer/encoder_norm/bias")
    return state_dict
```
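The conversion rules are pure layout changes, so they can be checked numerically without Kornia or JAX. The following is a minimal NumPy sketch with toy sizes (chosen here for illustration, not the ViT-L/16 sizes). It compares Flax-style computations against the converted weights applied PyTorch-style (`nn.Linear` computes `x @ W.T + b`; `nn.Conv2d` uses `(out, in, H, W)` kernels). It assumes the Kornia attention splits the fused `qkv` output as `(3, num_heads, head_dim)` and merges heads back head-major before `projection`; that is the layout the concatenation and flattening order imply, not something this card states.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, heads, tokens, patch, in_ch, img = 8, 2, 5, 4, 3, 8  # toy sizes, not ViT-L/16
head_dim = hidden // heads
names = ("query", "key", "value")

# 1) Patch embedding: Flax conv kernel (H, W, in, out) -> PyTorch weight (out, in, H, W).
#    Stride equals the kernel size, so the conv is a matmul over non-overlapping patches.
k_hwio = rng.standard_normal((patch, patch, in_ch, hidden))
x_hwc = rng.standard_normal((img, img, in_ch))
n = img // patch
ref_conv = np.einsum("ihjwc,hwco->ijo", x_hwc.reshape(n, patch, n, patch, in_ch), k_hwio)
w_oihw = k_hwio.transpose(3, 2, 0, 1)  # same axis order as .permute(3, 2, 0, 1)
x_chw = x_hwc.transpose(2, 0, 1)
conv = np.einsum("cihjw,ochw->ijo", x_chw.reshape(in_ch, n, patch, n, patch), w_oihw)

# 2) Fused qkv: three Flax kernels (hidden, heads, head_dim) -> one (3 * hidden, hidden) weight.
x = rng.standard_normal((tokens, hidden))
kern = {k: rng.standard_normal((hidden, heads, head_dim)) for k in names}
bias = {k: rng.standard_normal((heads, head_dim)) for k in names}
ref_qkv = [np.einsum("tc,chd->thd", x, kern[k]) + bias[k] for k in names]
qkv_w = np.concatenate([kern[k] for k in names], axis=1).reshape(hidden, -1).T
qkv_b = np.concatenate([bias[k] for k in names], axis=0).reshape(-1)
# Linear semantics (x @ W.T + b), then split the output as (3, heads, head_dim)
qkv = (x @ qkv_w.T + qkv_b).reshape(tokens, 3, heads, head_dim)

# 3) Output projection: Flax kernel (heads, head_dim, hidden) -> flatten(0, 1).T
z = rng.standard_normal((tokens, heads, head_dim))  # per-head attention outputs
out_k = rng.standard_normal((heads, head_dim, hidden))
ref_out = np.einsum("thd,hdf->tf", z, out_k)
z_flat = z.reshape(tokens, -1)  # heads merged head-major
proj_w = out_k.reshape(heads * head_dim, hidden).T  # with the transpose
no_t_w = out_k.reshape(heads * head_dim, hidden)    # same square shape, transpose forgotten
out = z_flat @ proj_w.T
out_no_t = z_flat @ no_t_w.T

print(np.allclose(conv, ref_conv),
      all(np.allclose(qkv[:, i], ref_qkv[i]) for i in range(3)),
      np.allclose(out, ref_out),
      np.allclose(out_no_t, ref_out))  # prints: True True True False
```

The last check shows why the output projection needs the `.T`: with `hidden == heads * head_dim` the untransposed weight has the right shape and loads without error, yet it computes a different projection. The MLP `Dense_0`/`Dense_1` kernels follow the same `(in, out)` to `(out, in)` rule.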