Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
############################################################# | |
# File: OSA.py | |
# Created Date: Tuesday April 28th 2022 | |
# Author: Chen Xuanhong | |
# Email: chenxuanhongzju@outlook.com | |
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm | |
# Modified By: Chen Xuanhong | |
# Copyright (c) 2020 Shanghai Jiao Tong University | |
############################################################# | |
import torch | |
import torch.nn.functional as F | |
from einops import rearrange, repeat | |
from einops.layers.torch import Rearrange, Reduce | |
from torch import einsum, nn | |
from .layernorm import LayerNorm2d | |
# helpers | |
def exists(val):
    """Return True for any value other than ``None`` (falsy values count)."""
    return not (val is None)
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if exists(val):
        return val
    return d
def cast_tuple(val, length=1):
    """Return ``val`` as-is if it is already a tuple.

    Otherwise build a new tuple containing ``val`` repeated ``length`` times.
    """
    if isinstance(val, tuple):
        return val
    return tuple(val for _ in range(length))
# helper classes | |
class PreNormResidual(nn.Module):
    """Pre-norm residual wrapper: ``fn(LayerNorm(x)) + x``.

    ``nn.LayerNorm(dim)`` normalises over the trailing dimension, so ``x`` is
    expected in channels-last layout ``(..., dim)``.

    Args:
        dim: size of the last dimension of ``x``.
        fn: module applied to the normalised input; its output must be
            broadcast-compatible with ``x``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Registration order (norm, then fn) fixes the state_dict key order.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        shortcut = x
        branch = self.fn(self.norm(x))
        return branch + shortcut
class Conv_PreNormResidual(nn.Module):
    """Pre-norm residual wrapper for feature maps: ``fn(LayerNorm2d(x)) + x``.

    ``LayerNorm2d`` is the project-local norm from ``.layernorm``. In this file
    the wrapper is applied to ``(b, d, H, W)`` tensors; presumably LayerNorm2d
    normalises over the channel axis — confirm against layernorm.py.

    Args:
        dim: number of channels passed to ``LayerNorm2d``.
        fn: module applied to the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Keep (norm, fn) registration order so state_dict keys are unchanged.
        self.norm = LayerNorm2d(dim)
        self.fn = fn

    def forward(self, x):
        shortcut = x
        branch = self.fn(self.norm(x))
        return branch + shortcut
class FeedForward(nn.Module):
    """Token-wise MLP on the last dimension.

    Layout of ``self.net`` (indices matter for checkpoints):
    0 Linear(dim, hidden) · 1 GELU · 2 Dropout · 3 Linear(hidden, dim) · 4 Dropout,
    where ``hidden = int(dim * mult)``.
    """

    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class Conv_FeedForward(nn.Module):
    """Per-pixel MLP on ``(b, c, h, w)`` maps, built from 1x1 convolutions.

    Layout of ``self.net`` (indices matter for checkpoints):
    0 Conv1x1(dim, hidden) · 1 GELU · 2 Dropout · 3 Conv1x1(hidden, dim) · 4 Dropout,
    where ``hidden = int(dim * mult)``.
    """

    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        layers = [
            nn.Conv2d(dim, hidden, 1, 1, 0),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, dim, 1, 1, 0),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
class Gated_Conv_FeedForward(nn.Module):
    """Gated convolutional feed-forward on ``(b, dim, h, w)`` maps.

    The pipeline is:
    1. Expand with a 1x1 conv to ``2 * hidden`` channels.
    2. Apply a 3x3 depthwise conv.
    3. Split the result into two halves along channels.
    4. Compute ``gelu(first) * second``.
    5. Project back to ``dim`` with a 1x1 conv.

    Here ``hidden = int(dim * mult)``.

    Note: ``dropout`` is accepted (the callers in this file pass it) but is not
    used anywhere in this module.
    """

    def __init__(self, dim, mult=1, bias=False, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        doubled = hidden * 2
        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(
            doubled,
            doubled,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=doubled,  # depthwise: one filter per channel
            bias=bias,
        )
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)
# MBConv | |
class SqueezeExcitation(nn.Module):
    """Channel gating for ``(b, c, h, w)`` maps.

    ``self.gate`` works as follows (indices matter for checkpoints):
    0. Spatial mean to shape ``(b, c)``.
    1. Linear(c, hidden), with no bias.
    2. SiLU.
    3. Linear(hidden, c), with no bias.
    4. Sigmoid.
    5. Reshape to ``(b, c, 1, 1)``.

    Here ``hidden = int(dim * shrinkage_rate)``. The input is scaled by this
    per-channel gate.
    """

    def __init__(self, dim, shrinkage_rate=0.25):
        super().__init__()
        hidden = int(dim * shrinkage_rate)
        stages = [
            Reduce("b c h w -> b c", "mean"),
            nn.Linear(dim, hidden, bias=False),
            nn.SiLU(),
            nn.Linear(hidden, dim, bias=False),
            nn.Sigmoid(),
            Rearrange("b c -> b c 1 1"),
        ]
        self.gate = nn.Sequential(*stages)

    def forward(self, x):
        weights = self.gate(x)
        return x * weights
class MBConvResidual(nn.Module):
    """Residual wrapper: ``Dropsample(dropout)(fn(x)) + x``.

    The branch output is passed through ``Dropsample`` (which may zero whole
    samples during training) before the skip connection is added.
    """

    def __init__(self, fn, dropout=0.0):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        branch = self.dropsample(self.fn(x))
        return branch + x
class Dropsample(nn.Module):
    """Randomly drop whole samples of a batch during training.

    In training mode each sample along dim 0 is zeroed with probability
    ``prob``. Surviving samples are scaled by ``1 / (1 - prob)`` so the
    expected value is unchanged. In eval mode, or when ``prob == 0``, the input
    is returned as-is.

    Args:
        prob: probability of dropping a sample (``0 <= prob < 1``).
    """

    def __init__(self, prob=0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        if self.prob == 0.0 or (not self.training):
            return x
        # One uniform draw per sample, shaped (batch, 1, ..., 1) so the mask
        # broadcasts over every remaining dimension of x.
        # The previous code was ``torch.FloatTensor((b, 1, 1, 1), device=...)``.
        # It had two problems:
        # - the legacy constructor treats the tuple as *data* (a 4-element
        #   vector) rather than as a shape;
        # - it rejects non-CPU devices.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        keep_mask = torch.rand(mask_shape, device=x.device) > self.prob
        return x * keep_mask / (1 - self.prob)
def MBConv(
    dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
):
    """Build an inverted-residual conv block.

    The layer order is:
    1. 1x1 expand to ``hidden = int(expansion_rate * dim_out)``.
    2. GELU.
    3. 3x3 depthwise conv, with stride 2 when ``downsample``.
    4. GELU.
    5. SqueezeExcitation.
    6. 1x1 project to ``dim_out``.

    No normalisation layers are used between the convolutions.

    When the block preserves shape (``dim_in == dim_out`` and not
    ``downsample``) it is wrapped in ``MBConvResidual``. In that case
    ``dropout`` is used as its Dropsample probability; otherwise ``dropout`` is
    ignored.

    Returns:
        ``nn.Sequential`` or ``MBConvResidual`` wrapping that Sequential.
    """
    hidden_dim = int(expansion_rate * dim_out)
    if downsample:
        stride = 2
    else:
        stride = 1

    # Sequential indices (0..5) are part of the checkpoint layout.
    layers = [
        nn.Conv2d(dim_in, hidden_dim, 1),
        nn.GELU(),
        nn.Conv2d(
            hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
        ),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
    ]
    block = nn.Sequential(*layers)

    keeps_shape = dim_in == dim_out and not downsample
    if not keeps_shape:
        return block
    return MBConvResidual(block, dropout=dropout)
# attention related classes | |
class Attention(nn.Module):
    """Multi-head self-attention computed independently inside each window.

    ``forward`` expects a 6-D tensor ``(b, x, y, w1, w2, d)``: a grid of
    ``x * y`` windows, each holding ``w1 * w2`` tokens of ``d`` channels.
    Tokens attend only to tokens in the same window. The output has the same
    shape as the input.

    Args:
        dim: channel size ``d``; must be divisible by ``dim_head``.
        dim_head: channels per head (``heads = dim // dim_head``).
        dropout: dropout on the attention weights and on the output projection.
        window_size: side length used to size the relative-position bias
            table. With ``with_pe=True`` the bias is ``(heads, ws², ws²)``, so
            ``w1 * w2`` must equal ``window_size ** 2``. Normally
            ``w1 == w2 == window_size``.
        with_pe: add a learned relative positional bias to the attention logits.
    """

    def __init__(
        self,
        dim,
        dim_head=32,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"
        self.heads = dim // dim_head
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
        )

        # relative positional bias
        if self.with_pe:
            # One learned value per head for each of the (2*ws-1)^2 possible
            # (row, col) offsets between two positions of a window.
            self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

            pos = torch.arange(window_size)
            # Passing indexing="ij" reproduces the legacy default. It also
            # silences the torch>=1.10 warning about the missing argument.
            grid = torch.stack(torch.meshgrid(pos, pos, indexing="ij"))
            grid = rearrange(grid, "c i j -> (i j) c")
            rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
                grid, "j ... -> 1 j ..."
            )
            # Shift offsets from [-(ws-1), ws-1] to [0, 2*ws-2], then flatten
            # (row, col) into a single row-major index into rel_pos_bias.
            rel_pos += window_size - 1
            rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
                dim=-1
            )
            self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)

    def forward(self, x):
        _, height, width, window_height, window_width, _ = x.shape
        h = self.heads

        # flatten: every window becomes its own batch entry of w1*w2 tokens
        x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")

        # project for queries, keys, values and split heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (rearrange(t, "b n (h d ) -> b h n d", h=h) for t in (q, k, v))

        q = q * self.scale
        sim = einsum("b h i d, b h j d -> b h i j", q, k)

        if self.with_pe:
            bias = self.rel_pos_bias(self.rel_pos_indices)
            sim = sim + rearrange(bias, "i j h -> h i j")

        attn = self.attend(sim)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # merge heads, project, and restore the window grid
        out = rearrange(
            out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
        )
        out = self.to_out(out)
        return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
class Block_Attention(nn.Module):
    """Spatial multi-head self-attention inside non-overlapping windows.

    The input is ``(b, dim, H, W)`` with ``H`` and ``W`` divisible by
    ``window_size``. Steps:
    1. Q/K/V come from a 1x1 conv followed by a 3x3 depthwise conv.
    2. Each ``window_size x window_size`` tile is treated as an independent
       token set.
    3. The output is projected by a 1x1 conv.

    ``with_pe`` is stored but not used: no positional bias is applied here.
    """

    def __init__(
        self,
        dim,
        dim_head=32,
        bias=False,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"
        self.heads = dim // dim_head
        self.ps = window_size
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        qkv_channels = dim * 3
        self.qkv = nn.Conv2d(dim, qkv_channels, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            qkv_channels,
            qkv_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=qkv_channels,
            bias=bias,
        )
        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
        self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, height, width = x.shape
        ps = self.ps

        def to_windows(t):
            # (b, heads*d, H, W) -> (b * nwin, heads, ps*ps, d)
            return rearrange(
                t,
                "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
                h=self.heads,
                w1=ps,
                w2=ps,
            )

        projected = self.qkv_dwconv(self.qkv(x))
        q, k, v = (to_windows(t) for t in projected.chunk(3, dim=1))

        q = q * self.scale
        attn = self.attend(einsum("b h i d, b h j d -> b h i j", q, k))
        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        out = rearrange(
            out,
            "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
            x=height // ps,
            y=width // ps,
            head=self.heads,
            w1=ps,
            w2=ps,
        )
        return self.to_out(out)
class Channel_Attention(nn.Module):
    """Channel-wise attention computed separately inside each window.

    The input is ``(b, dim, H, W)`` with ``H``/``W`` divisible by
    ``window_size`` and ``dim`` divisible by ``heads``. Per head, Q and K are
    L2-normalised over the window's ``ps*ps`` pixels. This forms a
    ``d x d`` channel-to-channel attention matrix, scaled by the learned
    per-head ``temperature``.

    ``dropout`` is accepted but not used by this module.
    """

    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super().__init__()
        self.heads = heads
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
        self.ps = window_size

        qkv_channels = dim * 3
        self.qkv = nn.Conv2d(dim, qkv_channels, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            qkv_channels,
            qkv_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=qkv_channels,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, height, width = x.shape
        ps = self.ps

        # (b, heads*d, H, W) -> (b, nwin, heads, d, ps*ps): one window per slot
        q, k, v = (
            rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
                ph=ps,
                pw=ps,
                head=self.heads,
            )
            for t in self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)
        )

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        out = torch.matmul(scores.softmax(dim=-1), v)

        out = rearrange(
            out,
            "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
            h=height // ps,
            w=width // ps,
            ph=ps,
            pw=ps,
            head=self.heads,
        )
        return self.project_out(out)
class Channel_Attention_grid(nn.Module):
    """Channel-wise attention over strided (grid) pixel groups.

    This differs from ``Channel_Attention`` only in how pixels are grouped.
    The ``(ph pw)`` offset inside each ``window_size`` tile selects the group.
    Each group therefore holds the pixels at that offset in every tile, and
    its ``(h w)`` pixels are the tokens. Per head, Q/K are L2-normalised over
    those pixels. This forms a ``d x d`` attention matrix scaled by the learned
    per-head ``temperature``.

    ``dropout`` is accepted but not used by this module.
    """

    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super().__init__()
        self.heads = heads
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
        self.ps = window_size

        qkv_channels = dim * 3
        self.qkv = nn.Conv2d(dim, qkv_channels, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            qkv_channels,
            qkv_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=qkv_channels,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, height, width = x.shape
        ps = self.ps

        # (b, heads*d, H, W) -> (b, ps*ps, heads, d, nwin): one grid offset per slot
        q, k, v = (
            rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
                ph=ps,
                pw=ps,
                head=self.heads,
            )
            for t in self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1)
        )

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.temperature
        out = torch.matmul(scores.softmax(dim=-1), v)

        out = rearrange(
            out,
            "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
            h=height // ps,
            w=width // ps,
            ph=ps,
            pw=ps,
            head=self.heads,
        )
        return self.project_out(out)
class OSA_Block(nn.Module):
    """Stack of conv, window-attention and channel-attention stages.

    The input is ``(b, channel_num, H, W)`` with ``H`` and ``W`` divisible by
    ``window_size``. ``channel_num`` should be divisible by 4, because the
    channel-attention layers use 4 heads. The output has the same shape.

    ``self.layer`` is an ``nn.Sequential``. Its order below defines the
    checkpoint keys and must not change:

    0. MBConv
    1-3. Block partition, window attention, merge back
    4. Gated feed-forward
    5. Channel_Attention
    6. Gated feed-forward
    7-9. Grid partition (pixels spaced H/ws and W/ws apart), window attention,
       merge back
    10. Gated feed-forward
    11. Channel_Attention_grid
    12. Gated feed-forward

    Note: ``bias`` and ``ffn_bias`` are accepted but never forwarded. The
    sub-layers use their own defaults.
    """

    def __init__(
        self,
        channel_num=64,
        bias=True,
        ffn_bias=True,
        window_size=8,
        with_pe=False,
        dropout=0.0,
    ):
        super().__init__()
        c = channel_num
        ws = window_size

        def window_attention():
            return PreNormResidual(
                c,
                Attention(
                    dim=c,
                    dim_head=c // 4,
                    dropout=dropout,
                    window_size=ws,
                    with_pe=with_pe,
                ),
            )

        def gated_ffn():
            return Conv_PreNormResidual(
                c, Gated_Conv_FeedForward(dim=c, dropout=dropout)
            )

        def channel_attention(attention_cls):
            return Conv_PreNormResidual(
                c, attention_cls(dim=c, heads=4, dropout=dropout, window_size=ws)
            )

        # Modules are built in list order, matching the original construction
        # order (and therefore the RNG consumption order for weight init).
        stages = [
            MBConv(c, c, downsample=False, expansion_rate=1, shrinkage_rate=0.25),
            # block-like attention
            Rearrange("b d (x w1) (y w2) -> b x y w1 w2 d", w1=ws, w2=ws),
            window_attention(),
            Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
            gated_ffn(),
            # channel-like attention
            channel_attention(Channel_Attention),
            gated_ffn(),
            # grid-like attention
            Rearrange("b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=ws, w2=ws),
            window_attention(),
            Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
            gated_ffn(),
            # channel-like attention
            channel_attention(Channel_Attention_grid),
            gated_ffn(),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)