import torch |
import numpy as np |
from torch import nn |
from torch.nn import functional as F |
from einops.layers.torch import Rearrange |
import math |
def default(val, default_val):
    """Return ``val`` unless it is ``None``, in which case return ``default_val``.

    Only ``None`` triggers the fallback; other falsy values such as ``0`` or
    ``False`` are returned unchanged.
    """
    if val is None:
        return default_val
    return val
def init_(tensor):
    """Fill ``tensor`` in place with uniform samples and return it.

    Samples are drawn from ``[-b, b)`` where ``b = 1 / sqrt(tensor.shape[-1])``,
    i.e. the bound is derived from the size of the last axis.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    # uniform_ mutates in place and hands back the same tensor object.
    return tensor.uniform_(-bound, bound)
class Residual(nn.Module):
    """Wrap a callable with a skip connection: ``forward(x) = x + fn(x)``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return x + branch
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then pass the result to ``fn``.

    Args:
        dim: size of the normalized (last) axis.
        fn: callable invoked on the normalized tensor; its result is returned.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # fn is assigned before the norm so submodule registration order is fn, norm.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
class GELU_(nn.Module):
    """GELU activation computed with the tanh approximation.

    ``forward(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    """

    def forward(self, x):
        inner = x + 0.044715 * torch.pow(x, 3)
        gate = torch.tanh(math.sqrt(2 / math.pi) * inner)
        return 0.5 * x * (1 + gate)
# Use the activation class shipped with torch when it exists; otherwise fall
# back to the local tanh-approximation implementation.
if hasattr(nn, 'GELU'):
    GELU = nn.GELU
else:
    GELU = GELU_
class FeedForward(nn.Module):
    """Two-layer position-wise MLP with an optional gated (GLU) first layer.

    Args:
        dim: size of the last axis of the input and of the output.
        hidden_dim: width fed into the second linear layer.
        dropout: dropout probability applied after the activation.
        activation: module class instantiated with no arguments; ``GELU``
            is used when ``None``.
        glu: when True the first layer emits ``2 * hidden_dim`` features that
            are split in half; the activated first half is multiplied by the
            second half.
    """

    def __init__(self, dim, hidden_dim, dropout = 0., activation = None, glu = False):
        super().__init__()
        act_cls = default(activation, GELU)
        first_out = hidden_dim * 2 if glu else hidden_dim

        self.glu = glu
        self.w1 = nn.Linear(dim, first_out)
        self.act = act_cls()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(hidden_dim, dim)

    def forward(self, x, **kwargs):
        hidden = self.w1(x)
        if self.glu:
            # Split the doubled projection into the part to activate and its gate.
            to_activate, gate = hidden.chunk(2, dim=-1)
            hidden = self.act(to_activate) * gate
        else:
            hidden = self.act(hidden)
        return self.w2(self.dropout(hidden))
class LinformerSelfAttention(nn.Module):
    """Multi-head attention whose keys and values are projected along the
    sequence axis down to a fixed length ``k`` before the dot product.

    Args:
        dim: size of the last axis of the input (and of the output).
        seq_len: key/value sequence length; ``forward`` asserts the incoming
            length equals this value.
        k: length the key/value sequence is projected to.
        heads: number of attention heads; ``dim`` must be divisible by it.
        dim_head: per-head width, ``dim // heads`` when ``None``.
        one_kv_head: when True a single key/value head is produced and
            broadcast across all query heads.
        share_kv: when True the values reuse the key layers; ``to_v`` and
            ``proj_v`` are not created.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, dim, seq_len, k = 16, heads = 4, dim_head = None, one_kv_head = False, share_kv = False, dropout = 0.):
        super().__init__()
        assert (dim % heads) == 0, 'dimension must be divisible by the number of heads'

        dim_head = default(dim_head, dim // heads)
        inner_dim = dim_head * heads
        kv_dim = inner_dim if not one_kv_head else dim_head

        self.seq_len = seq_len
        self.k = k
        self.heads = heads
        self.dim_head = dim_head
        self.share_kv = share_kv

        # Layers are created in a fixed order: it determines state_dict
        # ordering and the order in which random initial values are drawn.
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, kv_dim, bias = False)
        self.proj_k = nn.Parameter(init_(torch.zeros(seq_len, k)))

        if not share_kv:
            self.to_v = nn.Linear(dim, kv_dim, bias = False)
            self.proj_v = nn.Parameter(init_(torch.zeros(seq_len, k)))

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context = None, **kwargs):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when ``context`` is None).

        ``x`` must be 3-D; its axes are treated as (batch, sequence, feature).
        The result has the same batch and sequence sizes as ``x``.
        """
        batch, q_len, _ = x.shape
        head_dim, num_heads, proj_len = self.dim_head, self.heads, self.k

        kv_input = x if context is None else context
        kv_len = kv_input.shape[1]
        assert kv_len == self.seq_len, f'the sequence length of the key / values must be {self.seq_len} - {kv_len} given'

        queries = self.to_q(x)
        keys = self.to_k(kv_input)
        if self.share_kv:
            values, value_proj = keys, self.proj_k
        else:
            values, value_proj = self.to_v(kv_input), self.proj_v

        # Contract the sequence axis (n) against the learned (n, k) projections.
        keys = torch.einsum('bnd,nk->bkd', keys, self.proj_k)
        values = torch.einsum('bnd,nk->bkd', values, value_proj)

        # (batch, n, heads * d_h) -> (batch, heads, n, d_h)
        queries = queries.reshape(batch, q_len, num_heads, -1).transpose(1, 2)

        # (batch, k, kv_heads * d_h) -> (batch, kv_heads, k, d_h); expand
        # broadcasts a single key/value head across every query head.
        keys = keys.reshape(batch, proj_len, -1, head_dim).transpose(1, 2)
        keys = keys.expand(-1, num_heads, -1, -1)
        values = values.reshape(batch, proj_len, -1, head_dim).transpose(1, 2)
        values = values.expand(-1, num_heads, -1, -1)

        scores = torch.einsum('bhnd,bhkd->bhnk', queries, keys) * (head_dim ** -0.5)
        weights = self.dropout(scores.softmax(dim=-1))
        out = torch.einsum('bhnk,bhkd->bhnd', weights, values)

        # Merge the heads back into one feature axis.
        out = out.transpose(1, 2).reshape(batch, q_len, -1)
        return self.to_out(out)
class LinformerBlock(nn.Module):
    """Residual block: LayerNorm -> Linformer self-attention, then
    LayerNorm -> feed-forward, each wrapped in a skip connection.

    Args:
        d_model: feature size of the input; also the LayerNorm width.
        d_ffn: hidden width of the feed-forward network.
        seq_len: sequence length handed to ``LinformerSelfAttention``.
        dropout: dropout probability forwarded to both sub-layers.
        k: projected key/value length used by the attention layer
            (keyword-only, default 256).
        heads: number of attention heads (keyword-only, default 8).
    """

    def __init__(self, d_model, d_ffn, seq_len,dropout, *, k = 256, heads = 8):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.Linformer_unit = LinformerSelfAttention(
            d_model,
            seq_len,
            k = k,
            heads = heads,
            dim_head = None,
            one_kv_head = False,
            share_kv = False,
            dropout = dropout,
        )
        self.ffn = FeedForward(d_model,d_ffn,dropout)

    def forward(self, x):
        # NOTE: one LayerNorm instance (one set of affine parameters) is
        # applied before both sub-layers. It is left shared so the module's
        # parameter names and count stay the same.
        x = x + self.Linformer_unit(self.norm(x))
        return x + self.ffn(self.norm(x))
class LinearizerGatingUnit(nn.Module):
    """Multiply a linear projection of the input element-wise with the
    output of a ``LinformerBlock`` applied to the same input.

    Args:
        d_model: feature size of the input; the projection is d_model -> d_model.
        d_ffn: forwarded to ``LinformerBlock``.
        seq_len: forwarded to ``LinformerBlock``.
        dropout: forwarded to ``LinformerBlock``.
    """

    def __init__(self,d_model,d_ffn,seq_len,dropout):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)
        self.Linz = LinformerBlock(d_model, d_ffn, seq_len, dropout)

    def forward(self, x):
        projected = self.proj(x)
        attended = self.Linz(x)
        return projected * attended
class LinearizerBlock(nn.Module):
    """Residual block: LayerNorm -> gating unit, then LayerNorm -> feed-forward,
    each followed by adding back the sub-layer's un-normalized input.

    The same ``self.norm`` instance is applied before both sub-layers.

    Args:
        d_model: feature size of the input; also the LayerNorm width.
        d_ffn: hidden width forwarded to the gating unit and the feed-forward.
        seq_len: forwarded to ``LinearizerGatingUnit``.
        dropout: forwarded to both sub-layers.
    """

    def __init__(self, d_model,d_ffn,seq_len,dropout):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.lgu = LinearizerGatingUnit(d_model, d_ffn, seq_len, dropout)
        self.ffn = FeedForward(d_model, d_ffn, dropout)

    def forward(self, x):
        # Gating sub-layer with skip connection.
        skip = x
        gated = self.lgu(self.norm(x))
        x = gated + skip

        # Feed-forward sub-layer with skip connection.
        skip = x
        transformed = self.ffn(self.norm(x))
        return transformed + skip
class Linearizer(nn.Module):
    """Stack of ``num_layers`` ``LinearizerBlock`` modules applied in sequence.

    Args:
        d_model, d_ffn, seq_len, dropout: forwarded unchanged to every block.
        num_layers: number of blocks placed in the ``nn.Sequential``.
    """

    def __init__(self, d_model, d_ffn,seq_len, num_layers,dropout):
        super().__init__()
        stack = (
            LinearizerBlock(d_model, d_ffn, seq_len, dropout)
            for _ in range(num_layers)
        )
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        out = self.model(x)
        return out