Update README.md
---
license: apache-2.0
pipeline_tag: image-classification
---

PyTorch weights for the Kornia ViT, converted from the original Google JAX vision_transformer repo.

The original weights come from https://github.com/google-research/vision_transformer; this checkpoint is based on the
[original ViT-S/32 pretrained on ImageNet-21k](https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz).

The weights were converted to PyTorch for the Kornia ViT implementation by [@gau-nernst](https://github.com/gau-nernst) in [kornia/kornia#2786](https://github.com/kornia/kornia/pull/2786#discussion_r1482339811).
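
To use the converted checkpoint, load the state dict into Kornia's `VisionTransformer`. The snippet below is a minimal, hedged sketch rather than official usage: the constructor keyword names (including `hidden_dim`) are assumed to match `kornia.contrib.VisionTransformer`, and the local filename `vit_s32_kornia.pth` is a hypothetical placeholder for the converted weight file.

```python
import torch
from kornia.contrib import VisionTransformer

# ViT-S/32 hyperparameters (384-dim embeddings, 12 blocks, 6 heads, 32x32 patches);
# the keyword names are assumed to match kornia.contrib.VisionTransformer.
model = VisionTransformer(
    image_size=224,
    patch_size=32,
    embed_dim=384,
    depth=12,
    num_heads=6,
    hidden_dim=1536,  # MLP hidden size (4 * embed_dim); assumed keyword
)

# "vit_s32_kornia.pth" is a hypothetical local filename for the converted state dict.
state_dict = torch.load("vit_s32_kornia.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```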

<details>

<summary>Convert JAX checkpoint function</summary>

```python
import numpy as np
import torch


def convert_jax_checkpoint(np_state_dict: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:

    def get_weight(key: str) -> torch.Tensor:
        return torch.from_numpy(np_state_dict[key])

    state_dict = dict()
    state_dict["patch_embedding.cls_token"] = get_weight("cls")
    # patch-embedding conv kernel: JAX stores HWIO, PyTorch expects OIHW
    state_dict["patch_embedding.backbone.weight"] = get_weight("embedding/kernel").permute(3, 2, 0, 1)
    state_dict["patch_embedding.backbone.bias"] = get_weight("embedding/bias")
    state_dict["patch_embedding.positions"] = get_weight("Transformer/posembed_input/pos_embedding").squeeze(0)

    # iterate over encoder blocks until the checkpoint runs out of them
    for i in range(100):
        prefix1 = f"encoder.blocks.{i}"
        prefix2 = f"Transformer/encoderblock_{i}"

        if f"{prefix2}/LayerNorm_0/scale" not in np_state_dict:
            break

        state_dict[f"{prefix1}.0.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_0/scale")
        state_dict[f"{prefix1}.0.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_0/bias")

        mha_prefix = f"{prefix2}/MultiHeadDotProductAttention_1"
        qkv_weight = [get_weight(f"{mha_prefix}/{x}/kernel") for x in ["query", "key", "value"]]
        qkv_bias = [get_weight(f"{mha_prefix}/{x}/bias") for x in ["query", "key", "value"]]
        state_dict[f"{prefix1}.0.fn.1.qkv.weight"] = torch.cat(qkv_weight, 1).flatten(1).T
        state_dict[f"{prefix1}.0.fn.1.qkv.bias"] = torch.cat(qkv_bias, 0).flatten()
        # attention output kernel: (heads, head_dim, dim) -> (dim, heads * head_dim)
        state_dict[f"{prefix1}.0.fn.1.projection.weight"] = get_weight(f"{mha_prefix}/out/kernel").flatten(0, 1).T
        state_dict[f"{prefix1}.0.fn.1.projection.bias"] = get_weight(f"{mha_prefix}/out/bias")

        state_dict[f"{prefix1}.1.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_2/scale")
        state_dict[f"{prefix1}.1.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_2/bias")
        state_dict[f"{prefix1}.1.fn.1.0.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/kernel").T
        state_dict[f"{prefix1}.1.fn.1.0.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/bias")
        state_dict[f"{prefix1}.1.fn.1.3.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/kernel").T
        state_dict[f"{prefix1}.1.fn.1.3.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/bias")

    state_dict["norm.weight"] = get_weight("Transformer/encoder_norm/scale")
    state_dict["norm.bias"] = get_weight("Transformer/encoder_norm/bias")
    return state_dict
```
</details>
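
For reference, a short sketch of how the conversion function above might be driven end to end: download the original `.npz` linked above, load it with NumPy, convert, and save the PyTorch state dict. It assumes the `.npz` stores flat, `/`-separated keys (as `convert_jax_checkpoint` expects), and the output filename `vit_s32_kornia.pth` is a hypothetical placeholder.

```python
import urllib.request

import numpy as np
import torch

# Original augreg ViT-S/32 checkpoint linked above.
URL = (
    "https://storage.googleapis.com/vit_models/augreg/"
    "S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz"
)
urllib.request.urlretrieve(URL, "vit_s32_jax.npz")

# Assumes flat "/"-separated keys, e.g. "Transformer/encoderblock_0/LayerNorm_0/scale".
jax_ckpt = np.load("vit_s32_jax.npz")
np_state_dict = {k: jax_ckpt[k] for k in jax_ckpt.files}

state_dict = convert_jax_checkpoint(np_state_dict)
torch.save(state_dict, "vit_s32_kornia.pth")  # hypothetical output filename
```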