---
license: apache-2.0
pipeline_tag: image-classification
---

PyTorch weights for the Kornia ViT, converted from the original Google JAX vision_transformer repo.

Original weights from https://github.com/google-research/vision_transformer: this checkpoint is based on the [original ViT_B/16 pretrained on ImageNet-21k](https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz).

Weights were converted to PyTorch for the Kornia ViT implementation (by [@gau-nernst](https://github.com/gau-nernst) in [kornia/kornia#2786](https://github.com/kornia/kornia/pull/2786#discussion_r1482339811)).

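Below is a minimal loading sketch (not part of the original card). It assumes the converted checkpoint has been saved locally as `vit_kornia.pth` (a hypothetical filename) and that the default `kornia.contrib.VisionTransformer` configuration corresponds to this checkpoint.

```python
import torch
from kornia.contrib import VisionTransformer

# Hypothetical local filename for the converted checkpoint from this repo.
state_dict = torch.load("vit_kornia.pth", map_location="cpu")

# Assumes the default VisionTransformer configuration matches this checkpoint;
# construct the model with the matching configuration for other variants.
model = VisionTransformer()
model.load_state_dict(state_dict)
model.eval()

# Dummy forward pass: the ViT expects a batch of 3x224x224 images.
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)
```
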
<details>

<summary>Convert JAX checkpoint function</summary>

```python
import numpy as np
import torch


def convert_jax_checkpoint(np_state_dict: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:

    def get_weight(key: str) -> torch.Tensor:
        return torch.from_numpy(np_state_dict[key])

    state_dict = dict()
    state_dict["patch_embedding.cls_token"] = get_weight("cls")
    # conv kernel: JAX stores (H, W, in_ch, out_ch); PyTorch expects (out_ch, in_ch, H, W)
    state_dict["patch_embedding.backbone.weight"] = get_weight("embedding/kernel").permute(3, 2, 0, 1)
    state_dict["patch_embedding.backbone.bias"] = get_weight("embedding/bias")
    state_dict["patch_embedding.positions"] = get_weight("Transformer/posembed_input/pos_embedding").squeeze(0)

    # for i, block in enumerate(self.encoder.blocks):
    for i in range(100):
        prefix1 = f"encoder.blocks.{i}"
        prefix2 = f"Transformer/encoderblock_{i}"

        if f"{prefix2}/LayerNorm_0/scale" not in np_state_dict:
            break

        state_dict[f"{prefix1}.0.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_0/scale")
        state_dict[f"{prefix1}.0.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_0/bias")

        # fuse the separate query/key/value projections into a single qkv linear layer
        mha_prefix = f"{prefix2}/MultiHeadDotProductAttention_1"
        qkv_weight = [get_weight(f"{mha_prefix}/{x}/kernel") for x in ["query", "key", "value"]]
        qkv_bias = [get_weight(f"{mha_prefix}/{x}/bias") for x in ["query", "key", "value"]]
        state_dict[f"{prefix1}.0.fn.1.qkv.weight"] = torch.cat(qkv_weight, 1).flatten(1).T
        state_dict[f"{prefix1}.0.fn.1.qkv.bias"] = torch.cat(qkv_bias, 0).flatten()
        state_dict[f"{prefix1}.0.fn.1.projection.weight"] = get_weight(f"{mha_prefix}/out/kernel").flatten(0, 1).T
        state_dict[f"{prefix1}.0.fn.1.projection.bias"] = get_weight(f"{mha_prefix}/out/bias")

        # MLP block: JAX Dense kernels are (in, out), so transpose for nn.Linear
        state_dict[f"{prefix1}.1.fn.0.weight"] = get_weight(f"{prefix2}/LayerNorm_2/scale")
        state_dict[f"{prefix1}.1.fn.0.bias"] = get_weight(f"{prefix2}/LayerNorm_2/bias")
        state_dict[f"{prefix1}.1.fn.1.0.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/kernel").T
        state_dict[f"{prefix1}.1.fn.1.0.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_0/bias")
        state_dict[f"{prefix1}.1.fn.1.3.weight"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/kernel").T
        state_dict[f"{prefix1}.1.fn.1.3.bias"] = get_weight(f"{prefix2}/MlpBlock_3/Dense_1/bias")

    state_dict["norm.weight"] = get_weight("Transformer/encoder_norm/scale")
    state_dict["norm.bias"] = get_weight("Transformer/encoder_norm/bias")
    return state_dict
```
</details>

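For reference, a hedged end-to-end sketch of running the conversion above yourself: the input filename matches the AugReg checkpoint linked earlier, and the output filename is hypothetical.

```python
import numpy as np
import torch

# Local copy of the original JAX checkpoint (.npz) linked above.
npz = np.load("B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz")
np_state_dict = {k: npz[k] for k in npz.files}

# Convert the flat JAX parameter names/layouts to the Kornia ViT state dict.
state_dict = convert_jax_checkpoint(np_state_dict)
torch.save(state_dict, "vit_kornia.pth")  # hypothetical output filename
```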