Paolo-Fraccaro
committed on
Commit
·
b8e0a76
1
Parent(s):
c70cc56
Update Prithvi.py
Browse files- Prithvi.py +31 -3
Prithvi.py
CHANGED
@@ -15,12 +15,42 @@ import torch
|
|
15 |
import torch.nn as nn
|
16 |
|
17 |
from timm.models.vision_transformer import Block
|
18 |
-
from timm.models.layers import to_2tuple
|
19 |
|
20 |
import numpy as np
|
21 |
|
22 |
from einops import rearrange
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
25 |
"""
|
26 |
grid_size: 3d tuple of grid size: t, h, w
|
@@ -85,8 +115,6 @@ class PatchEmbed(nn.Module):
|
|
85 |
|
86 |
def forward(self, x):
|
87 |
B, C, T, H, W = x.shape
|
88 |
-
_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
|
89 |
-
_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
|
90 |
x = self.proj(x)
|
91 |
if self.flatten:
|
92 |
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
|
|
|
15 |
import torch.nn as nn
|
16 |
|
17 |
from timm.models.vision_transformer import Block
|
18 |
+
from timm.models.layers import to_2tuple
|
19 |
|
20 |
import numpy as np
|
21 |
|
22 |
from einops import rearrange
|
23 |
|
24 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Build a 1D sine/cosine positional embedding.

    embed_dim: output dimension for each position (must be even)
    pos: positions to be encoded; a NumPy array or any array-like
         (e.g. a plain list), flattened to size (M,)
    out: (M, D) -- sine terms in the first D/2 columns, cosine terms in the last D/2
    """
    assert embed_dim % 2 == 0

    # Per-channel frequencies: 1 / 10000^(i / (D/2)) for i in [0, D/2)
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    # np.asarray lets callers pass a plain list (as the docstring advertises);
    # existing ndarray inputs pass through unchanged.
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
|
43 |
+
|
44 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Build a 2D sine/cosine positional embedding from a pair of coordinate grids.

    embed_dim: output dimension for each position (must be even)
    grid: indexable holding two coordinate arrays; grid[0] fills the first
          half of the channels and grid[1] the second half
    out: the two 1D embeddings (each embed_dim // 2 wide) joined along axis 1
    """
    assert embed_dim % 2 == 0

    # Each of the two grid axes gets half of the embedding channels.
    half_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis])
        for axis in (0, 1)
    ]

    return np.concatenate(per_axis, axis=1)
|
53 |
+
|
54 |
def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
55 |
"""
|
56 |
grid_size: 3d tuple of grid size: t, h, w
|
|
|
115 |
|
116 |
def forward(self, x):
|
117 |
B, C, T, H, W = x.shape
|
|
|
|
|
118 |
x = self.proj(x)
|
119 |
if self.flatten:
|
120 |
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
|