import collections.abc
from itertools import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F

from .droppath import DropPath


def constant_init(tensor, constant=0.0):
    nn.init.constant_(tensor, constant)
    return tensor


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse
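

# For example, _ntuple(2)(7) returns (7, 7), while _ntuple(2)((3, 5)) passes the
# iterable through unchanged.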


class Mlp(nn.Module):
    def __init__(
        self,
        in_features=None,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = activation
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, train: bool = True):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x) if train else x
        x = self.fc2(x)
        x = self.drop(x) if train else x
        return x


class Attention(nn.Module):
    """
    Default multi-head self-attention.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.xavier_uniform_(self.qkv.weight)
        nn.init.xavier_uniform_(self.proj.weight)

    def forward(self, x, train: bool = True):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn) if train else attn

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x) if train else x
        return x
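

# A minimal usage sketch (hypothetical sizes): Attention keeps the token shape, e.g.
#   attn = Attention(dim=64, num_heads=8)
#   y = attn(torch.randn(2, 16, 64), train=False)   # y.shape == (2, 16, 64)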


def window_partition1d(x, window_size):
    B, W, C = x.shape
    x = x.view(B, W // window_size, window_size, C)
    windows = x.view(-1, window_size, C)
    return windows


def window_reverse1d(windows, window_size, W: int):
    B = int(windows.shape[0] / (W / window_size))
    x = windows.view(B, W // window_size, window_size, -1)
    x = x.view(B, W, -1)
    return x
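

# window_partition1d reshapes (B, W, C) into (B * W // window_size, window_size, C)
# and window_reverse1d inverts it, so W must be divisible by window_size. For
# example, a (2, 64, 16) tensor with window_size=8 becomes (16, 8, 16) and back.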


def get_relative_position_index1d(win_w):
    coords = torch.stack(torch.meshgrid(torch.arange(win_w)))

    relative_coords = coords[:, :, None] - coords[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0)

    relative_coords[:, :, 0] += win_w - 1

    return relative_coords.sum(-1)
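

# For win_w=3 the returned index matrix is
#   [[2, 1, 0],
#    [3, 2, 1],
#    [4, 3, 2]]
# i.e. each entry indexes one of the 2 * win_w - 1 possible relative positions.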


class WindowedAttentionHead(nn.Module):
    def __init__(self, head_dim, window_size, shift_windows=False, attn_drop=0.0):
        super().__init__()
        self.head_dim = head_dim
        self.window_size = window_size
        self.shift_windows = shift_windows
        self.attn_drop = attn_drop

        self.scale = self.head_dim**-0.5
        self.window_area = self.window_size

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1, 1))
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        self.register_buffer(
            "relative_position_index", get_relative_position_index1d(window_size)
        )

        self.drop_layer = nn.Dropout(attn_drop) if attn_drop > 0 else None

        self.shift_size = window_size // 2 if shift_windows else 0
        assert 0 <= self.shift_size < self.window_size, (
            "shift_size must be in the range [0, window_size)"
        )

    def forward(self, q, k, v, train: bool = True):
        B, W, C = q.shape

        mask = None
        if self.shift_size > 0:
            # Build an attention mask that prevents tokens from attending across
            # the wrap-around boundary introduced by the cyclic shift.
            img_mask = torch.zeros((1, W, 1), device=q.device)
            cnt = 0
            for w in (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            ):
                img_mask[:, w, :] = cnt
                cnt += 1
            mask_windows = window_partition1d(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size)
            mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            mask = mask.masked_fill(mask != 0, -100.0).masked_fill(mask == 0, 0.0)

            q = torch.roll(q, shifts=-self.shift_size, dims=1)
            k = torch.roll(k, shifts=-self.shift_size, dims=1)
            v = torch.roll(v, shifts=-self.shift_size, dims=1)

        q = window_partition1d(q, self.window_size)
        k = window_partition1d(k, self.window_size)
        v = window_partition1d(v, self.window_size)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn + self._get_rel_pos_bias()

        if mask is not None:
            B_, N, _ = attn.shape
            num_win = mask.shape[0]
            attn = attn.view(B_ // num_win, num_win, N, N) + mask.unsqueeze(0)
            attn = attn.view(-1, N, N)
        attn = attn.softmax(dim=-1)

        if self.drop_layer is not None and train:
            attn = self.drop_layer(attn)

        x = attn @ v

        shifted_x = window_reverse1d(x, self.window_size, W=W)

        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=self.shift_size, dims=1)
        else:
            x = shifted_x

        return x, attn

    def _get_rel_pos_bias(self):
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(self.window_area, self.window_area, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1)
        return relative_position_bias
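

# A minimal sketch, assuming the sequence length is already a multiple of the
# window size (window_partition1d requires this):
#   head = WindowedAttentionHead(head_dim=16, window_size=8, shift_windows=True)
#   q = k = v = torch.randn(2, 64, 16)
#   out, attn = head(q, k, v, train=False)   # out: (2, 64, 16)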


class AttentionHead(nn.Module):
    def __init__(self, head_dim, attn_drop=0.0):
        super().__init__()
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.drop_layer = nn.Dropout(attn_drop) if attn_drop > 0 else None

    def forward(self, q, k, v, train: bool = True):
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        if self.drop_layer is not None and train:
            attn = self.drop_layer(attn)

        x = attn @ v
        return x, attn


class WindowedMultiHeadAttention(nn.Module):
    def __init__(
        self,
        dim,
        window_sizes,
        shift_windows=False,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        nn.init.xavier_uniform_(self.qkv.weight)

        if isinstance(window_sizes, int):
            window_sizes = _ntuple(num_heads)(window_sizes)
        else:
            assert len(window_sizes) == num_heads

        # A window size of 0 selects a global (non-windowed) attention head.
        self.attn_heads = nn.ModuleList()
        for i in range(num_heads):
            ws_i = window_sizes[i]
            if ws_i == 0:
                self.attn_heads.append(AttentionHead(self.head_dim, attn_drop))
            else:
                self.attn_heads.append(
                    WindowedAttentionHead(
                        self.head_dim,
                        window_size=ws_i,
                        shift_windows=shift_windows,
                        attn_drop=attn_drop,
                    )
                )

        self.proj = nn.Linear(dim, dim)
        nn.init.xavier_uniform_(self.proj.weight)
        self.drop_layer = nn.Dropout(proj_drop) if proj_drop > 0 else None

    def forward(self, x, train: bool = True):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 3, 0, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        o = []
        for i in range(self.num_heads):
            head_i, _ = self.attn_heads[i](q[i], k[i], v[i], train=train)
            o.append(head_i.unsqueeze(0))

        o = torch.cat(o, dim=0)
        o = o.permute(1, 2, 0, 3).reshape(B, N, -1)
        o = self.proj(o)

        if self.drop_layer is not None and train:
            o = self.drop_layer(o)

        return o
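

# A minimal sketch, assuming dim=64, num_heads=8, and a sequence length that is
# a multiple of every non-zero window size (here 4, 8, and 16):
#   mwmha = WindowedMultiHeadAttention(
#       dim=64, window_sizes=(0, 0, 4, 4, 8, 8, 16, 16), num_heads=8
#   )
#   y = mwmha(torch.randn(2, 128, 64), train=False)   # y.shape == (2, 128, 64)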


class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x * self.gamma


class BNWrapper(nn.Module):
    def __init__(
        self, num_features, use_running_average=True, use_bias=True, use_scale=True
    ):
        super().__init__()
        # use_running_average is accepted for interface compatibility but is not
        # used here; PyTorch's BatchNorm1d selects batch vs. running statistics
        # from the module's train/eval mode.
        self.bn = nn.BatchNorm1d(num_features, affine=use_scale or use_bias)

    def forward(self, x, train=True):
        # nn.BatchNorm1d does not take a `train` argument; toggle the module mode
        # instead so the appropriate statistics are used.
        self.bn.train(train)
        return self.bn(x)


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        init_values=None,
        drop_path=0.0,
        act_layer=F.gelu,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            out_features=dim,
            activation=act_layer,
            drop=drop,
        )

        self.init_values = init_values
        if init_values is not None:
            self.layer_scale1 = LayerScale(dim, init_values)
            self.layer_scale2 = LayerScale(dim, init_values)

    def forward(self, x, train: bool = True):
        outputs1 = self.attn(self.norm1(x), train=train)

        if self.init_values is not None:
            outputs1 = self.layer_scale1(outputs1)

        x = x + self.drop_path(outputs1) if train else x + outputs1

        outputs2 = self.mlp(self.norm2(x), train=train)

        if self.init_values is not None:
            outputs2 = self.layer_scale2(outputs2)

        x = x + self.drop_path(outputs2) if train else x + outputs2
        return x
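

# A minimal sketch (hypothetical sizes): a pre-norm transformer block that keeps
# the token shape, e.g.
#   block = Block(dim=64, num_heads=8, init_values=1e-5)
#   y = block(torch.randn(2, 16, 64), train=False)   # y.shape == (2, 16, 64)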


class MWMHABlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        window_sizes,
        shift_windows=False,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        init_values=None,
        drop_path=0.0,
        act_layer=F.gelu,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.wmha = WindowedMultiHeadAttention(
            dim,
            window_sizes=window_sizes,
            shift_windows=shift_windows,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            out_features=dim,
            activation=act_layer,
            drop=drop,
        )

        self.init_values = init_values
        if init_values is not None:
            self.layer_scale1 = LayerScale(dim, init_values)
            self.layer_scale2 = LayerScale(dim, init_values)

    def forward(self, x, train: bool = True):
        outputs1 = self.wmha(self.norm1(x), train=train)

        if self.init_values is not None:
            outputs1 = self.layer_scale1(outputs1)

        x = x + self.drop_path(outputs1) if train else x + outputs1

        outputs2 = self.mlp(self.norm2(x), train=train)

        if self.init_values is not None:
            outputs2 = self.layer_scale2(outputs2)

        x = x + self.drop_path(outputs2) if train else x + outputs2
        return x
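

# A minimal sketch, assuming dim=64, num_heads=8, and a sequence length divisible
# by every non-zero window size:
#   block = MWMHABlock(dim=64, num_heads=8, window_sizes=(0, 0, 4, 4, 8, 8, 16, 16))
#   y = block(torch.randn(2, 128, 64), train=False)   # y.shape == (2, 128, 64)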