Spaces:
Running
Running
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from models.helpers import DropPath, drop_path | |
# this file only provides the 3 blocks used in VAR transformer | |
__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead'] | |
# automatically import fused operators | |
dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None | |
try: | |
from flash_attn.ops.layer_norm import dropout_add_layer_norm | |
from flash_attn.ops.fused_dense import fused_mlp_func | |
except ImportError: pass | |
# automatically import faster attention implementations | |
try: from xformers.ops import memory_efficient_attention | |
except ImportError: pass | |
try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq | |
except ImportError: pass | |
try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc | |
except ImportError: | |
def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0): | |
attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL | |
if attn_mask is not None: attn.add_(attn_mask) | |
return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value | |
class FFN(nn.Module): | |
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True): | |
super().__init__() | |
self.fused_mlp_func = fused_mlp_func if fused_if_available else None | |
out_features = out_features or in_features | |
hidden_features = hidden_features or in_features | |
self.fc1 = nn.Linear(in_features, hidden_features) | |
self.act = nn.GELU(approximate='tanh') | |
self.fc2 = nn.Linear(hidden_features, out_features) | |
self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity() | |
def forward(self, x): | |
if self.fused_mlp_func is not None: | |
return self.drop(self.fused_mlp_func( | |
x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias, | |
activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0, | |
heuristic=0, process_group=None, | |
)) | |
else: | |
return self.drop(self.fc2( self.act(self.fc1(x)) )) | |
def extra_repr(self) -> str: | |
return f'fused_mlp_func={self.fused_mlp_func is not None}' | |
class SelfAttention(nn.Module): | |
def __init__( | |
self, block_idx, embed_dim=768, num_heads=12, | |
attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True, | |
): | |
super().__init__() | |
assert embed_dim % num_heads == 0 | |
self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads # =64 | |
self.attn_l2_norm = attn_l2_norm | |
if self.attn_l2_norm: | |
self.scale = 1 | |
self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True) | |
self.max_scale_mul = torch.log(torch.tensor(100)).item() | |
else: | |
self.scale = 0.25 / math.sqrt(self.head_dim) | |
self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False) | |
self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim)) | |
self.register_buffer('zero_k_bias', torch.zeros(embed_dim)) | |
self.proj = nn.Linear(embed_dim, embed_dim) | |
self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity() | |
self.attn_drop: float = attn_drop | |
self.using_flash = flash_if_available and flash_attn_func is not None | |
self.using_xform = flash_if_available and memory_efficient_attention is not None | |
# only used during inference | |
self.caching, self.cached_k, self.cached_v = False, None, None | |
def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None | |
# NOTE: attn_bias is None during inference because kv cache is enabled | |
def forward(self, x, attn_bias): | |
B, L, C = x.shape | |
qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim) | |
main_type = qkv.dtype | |
# qkv: BL3Hc | |
using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32 | |
if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1 # q or k or v: BLHc | |
else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2 # q or k or v: BHLc | |
if self.attn_l2_norm: | |
scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp() | |
if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2) # 1H11 to 11H1 | |
q = F.normalize(q, dim=-1).mul(scale_mul) | |
k = F.normalize(k, dim=-1) | |
if self.caching: | |
if self.cached_k is None: self.cached_k = k; self.cached_v = v | |
else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat) | |
dropout_p = self.attn_drop if self.training else 0.0 | |
if using_flash: | |
oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C) | |
elif self.using_xform: | |
oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C) | |
else: | |
oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C) | |
return self.proj_drop(self.proj(oup)) | |
# attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb()) # BHLc @ BHcL => BHLL | |
# attn = self.attn_drop(attn.softmax(dim=-1)) | |
# oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1) # BHLL @ BHLc = BHLc => BLHc => BLC | |
def extra_repr(self) -> str: | |
return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}' | |
class AdaLNSelfAttn(nn.Module): | |
def __init__( | |
self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer, | |
num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False, | |
flash_if_available=False, fused_if_available=True, | |
): | |
super(AdaLNSelfAttn, self).__init__() | |
self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim | |
self.C, self.D = embed_dim, cond_dim | |
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() | |
self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available) | |
self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available) | |
self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False) | |
self.shared_aln = shared_aln | |
if self.shared_aln: | |
self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5) | |
else: | |
lin = nn.Linear(cond_dim, 6*embed_dim) | |
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) | |
self.fused_add_norm_fn = None | |
# NOTE: attn_bias is None during inference because kv cache is enabled | |
def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim | |
if self.shared_aln: | |
gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C | |
else: | |
gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2) | |
x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1)) | |
x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used | |
return x | |
def extra_repr(self) -> str: | |
return f'shared_aln={self.shared_aln}' | |
class AdaLNBeforeHead(nn.Module): | |
def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim | |
super().__init__() | |
self.C, self.D = C, D | |
self.ln_wo_grad = norm_layer(C, elementwise_affine=False) | |
self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C)) | |
def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor): | |
scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2) | |
return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift) | |