account18hackathon committed
Commit 26a9f0f · 1 parent: 4e40454
Upload 4 files
performer_pytorch/__init__.py
ADDED
@@ -0,0 +1 @@
from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention
performer_pytorch/autoregressive_wrapper.py
ADDED
@@ -0,0 +1,113 @@
from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

import pdb


def exists(val):
    return val is not None

def top_p(logits, thres = 0.9):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > (1 - thres)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = 0

    sorted_logits[sorted_indices_to_remove] = float('-inf')
    return sorted_logits.scatter(1, sorted_indices, sorted_logits)

def top_k(logits, thres = 0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

def repetition_penalty_fn(logits, ctx, theta=1.2):
    w = torch.ones(logits.shape[-1], dtype=torch.float, device=logits.device)
    for i in torch.unique(ctx):
        w[i] = theta
    return logits / w

class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, ignore_index = 0, pad_value = 0):
        super().__init__()
        self.pad_value = pad_value
        self.ignore_index = ignore_index

        self.net = net
        self.max_seq_len = net.max_seq_len

    @torch.no_grad()
    def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, repetition_penalty = 1.0, repetition_penalty_ctx = 32, **kwargs):
        was_training = self.net.training
        num_dims = len(start_tokens.shape)

        if num_dims == 1:
            start_tokens = start_tokens[None, :]

        b, t = start_tokens.shape

        self.net.eval()
        out = start_tokens
        input_mask = kwargs.pop('mask', None)

        if input_mask is None:
            input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

        # in case of conditional generation, if enc_mask is not provided use the correct context_mask
        context_mask = kwargs.pop('context_mask', None)

        if 'context' in kwargs and not exists(context_mask):
            context = kwargs['context']
            context_mask = torch.full(context.shape[:2], True, dtype=torch.bool, device=out.device)

        kwargs.update(context_mask = context_mask)

        for _ in range(seq_len):
            x = out[:, -self.max_seq_len:]
            input_mask = input_mask[:, -self.max_seq_len:]
            logits = self.net(x, mask=input_mask, **kwargs)[:, -1, :]
            if repetition_penalty > 1.0:
                logits = repetition_penalty_fn(logits, out[-repetition_penalty_ctx:], theta=repetition_penalty)
            filtered_logits = filter_logits_fn(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            input_mask = F.pad(input_mask, (0, 1), value=True)

            if eos_token is not None and (sample == eos_token).all():
                break

        out = out[:, t:]

        if num_dims == 1:
            out = out.squeeze(0)

        self.net.train(was_training)
        return out

    def forward(self, x, **kwargs):
        xi = x[:, :-1]
        xo = x[:, 1:]

        # help auto-solve an area of confusion around input masks in auto-regressive
        # if user supplies a mask that is only off by one from the source sequence, resolve it for them
        mask = kwargs.pop('mask', None)
        if mask is not None and mask.shape[1] == x.shape[1]:
            mask = mask[:, :-1]
            kwargs.update(mask = mask)

        out = self.net(xi, **kwargs)

        loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)

        #pdb.set_trace()

        return loss
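
A minimal usage sketch for the wrapper above, assuming a toy PerformerLM configuration: every hyperparameter here is hypothetical, and g2v_position_emb is disabled so the sketch does not depend on the Gene2Vec weight file.

import torch
from performer_pytorch import PerformerLM
from performer_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# hypothetical toy configuration, just to show the call signatures
lm = PerformerLM(
    num_tokens = 256,
    max_seq_len = 256,
    dim = 128,
    depth = 2,
    heads = 4,
    causal = True,
    g2v_position_emb = False,   # skip loading the Gene2Vec table in this sketch
)
model = AutoregressiveWrapper(lm, ignore_index = 0, pad_value = 0)

seq = torch.randint(0, 256, (2, 256))
loss = model(seq)     # next-token cross-entropy: predicts seq[:, 1:] from seq[:, :-1]
loss.backward()

prompt = torch.randint(0, 256, (1, 16))
sampled = model.generate(prompt, seq_len = 32, filter_thres = 0.9, temperature = 1.0)
# sampled has shape (1, 32): only the newly generated tokens are returned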
performer_pytorch/performer_pytorch.py
ADDED
@@ -0,0 +1,642 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast
from einops import rearrange, repeat

from functools import partial
from contextlib import contextmanager

from local_attention import LocalAttention
from performer_pytorch.reversible import ReversibleSequence, SequentialSequence

import pdb

try:
    from apex import amp
    APEX_AVAILABLE = True
except:
    APEX_AVAILABLE = False

# helpers

def exists(val):
    return val is not None

def empty(tensor):
    return tensor.numel() == 0

def default(val, d):
    return val if exists(val) else d

@contextmanager
def null_context():
    yield

def cast_tuple(val):
    return (val,) if not isinstance(val, tuple) else val

# def get_module_device(module):
#     return next(module.parameters).device

def get_module_device(module):
    try:
        return next(module.parameters()).device
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5
        def find_tensor_attributes(module):
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples
        gen = module._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device

def find_modules(nn_module, type):
    return [module for module in nn_module.modules() if isinstance(module, type)]

class Always(nn.Module):
    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, *args, **kwargs):
        return self.val

# kernel functions

# transcribed from jax to pytorch from
# https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py

def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    ratio = (projection_matrix.shape[0] ** -0.5)

    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)

    diag_data = data ** 2
    diag_data = torch.sum(diag_data, dim=-1)
    diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
    diag_data = diag_data.unsqueeze(dim=-1)

    if is_query:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data -
                      torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
    else:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)

    return data_dash.type_as(data)

def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)

    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime.type_as(data)

def orthogonal_matrix_chunk(cols, device = None):
    unstructured_block = torch.randn((cols, cols), device = device)
    q, r = torch.qr(unstructured_block.cpu(), some = True)
    q, r = map(lambda t: t.to(device), (q, r))
    return q.t()

def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None):
    nb_full_blocks = int(nb_rows / nb_columns)

    block_list = []

    for _ in range(nb_full_blocks):
        q = orthogonal_matrix_chunk(nb_columns, device = device)
        block_list.append(q)

    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0:
        q = orthogonal_matrix_chunk(nb_columns, device = device)
        block_list.append(q[:remaining_rows])

    final_matrix = torch.cat(block_list)

    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
    elif scaling == 1:
        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
    else:
        raise ValueError(f'Invalid scaling {scaling}')

    return torch.diag(multiplier) @ final_matrix

# linear attention classes with softmax kernel

# non-causal linear attention
def linear_attention(q, k, v):
    k_cumsum = k.sum(dim = -2)
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
    context = torch.einsum('...nd,...ne->...de', k, v)
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out

# efficient causal linear attention, created by EPFL
# TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back
def causal_linear_attention(q, k, v, eps = 1e-6):
    from fast_transformers.causal_product import CausalDotProduct
    autocast_enabled = torch.is_autocast_enabled()
    is_half = isinstance(q, torch.cuda.HalfTensor)
    assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available'
    cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False)

    causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply

    k_cumsum = k.cumsum(dim=-2) + eps
    D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q))

    with cuda_context():
        if autocast_enabled:
            q, k, v = map(lambda t: t.float(), (q, k, v))

        out = causal_dot_product_fn(q, k, v)

    out = torch.einsum('...nd,...n->...nd', out, D_inv)
    return out

# inefficient causal linear attention, without cuda code, for reader's reference
# not being used
def causal_linear_attention_noncuda(q, k, v, chunk_size = 128):
    last_k_cumsum = 0
    last_context_cumsum = 0
    outs = []

    for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim = -2), (q, k, v))):
        k_cumsum = last_k_cumsum + k.cumsum(dim=-2)

        D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q))
        context = torch.einsum('...nd,...ne->...nde', k, v)
        context_cumsum = last_context_cumsum + context.cumsum(dim=-3)
        out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv)

        last_k_cumsum = k_cumsum[:, :, -1:]
        last_context_cumsum = context_cumsum[:, :, -1:]
        outs.append(out)

    return torch.cat(outs, dim = -2)

def norm_tensor(tensor, dim=-1):
    return tensor / tensor.sum(dim=dim).unsqueeze(dim)

class FastAttention(nn.Module):
    def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))

        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling

        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling)
        projection_matrix = self.create_projection()
        self.register_buffer('projection_matrix', projection_matrix)

        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn

        # if this is turned on, no projection will be used
        # queries and keys will be softmax-ed as in the original efficient attention paper
        self.no_projection = no_projection

        self.causal = causal
        if causal:
            try:
                import fast_transformers.causal_product.causal_product_cuda
                self.causal_linear_fn = partial(causal_linear_attention)
            except ImportError:
                print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
                self.causal_linear_fn = causal_linear_attention_noncuda

    @torch.no_grad()
    def redraw_projection_matrix(self, device):
        projections = self.create_projection(device = device)
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v, output_attentions = False):
        device = q.device
        # inds = [8060, 8064, 6243, 8575, 10342, 10913, 9366, 993, 7796, 5210, 5212, 5504, 6851, 6559, 5508, 13107, 13820]
        if self.no_projection:
            q = q.softmax(dim = -1)
            k = torch.exp(k) if self.causal else k.softmax(dim = -2)

        elif self.generalized_attention:
            create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
            q, k = map(create_kernel, (q, k))

        else:
            create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
            q = create_kernel(q, is_query = True)
            k = create_kernel(k, is_query = False)

        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
        out = attn_fn(q, k, v)
        if output_attentions:
            # reconstruct an approximate attention map by pushing an identity matrix
            # through the linear attention, then averaging over heads
            v_diag = torch.eye(v.shape[-2]).to(device)
            v_diag = v_diag.unsqueeze(0).unsqueeze(0).repeat(v.shape[0], v.shape[1], 1, 1)
            # attn_weights = torch.zeros(1, 1, len(inds), len(inds)).to(device).to(torch.float16)
            # attn_weights = torch.zeros(1, q.shape[1], len(inds), len(inds)).to(device).to(torch.float16)
            attn_weights = torch.zeros(1, 1, q.shape[2], q.shape[2]).to(device).to(torch.float16)
            for head_dim in range(q.shape[1]):
                # attn_weights[0, head_dim] = torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16)))[0, inds][:, inds]
                attn_weights += torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16)))
                # attn_weights += norm_tensor(torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16))), dim=-1)
            attn_weights /= q.shape[1]
            return out, attn_weights
        else:
            return out

# classes

class ReZero(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.g = nn.Parameter(torch.tensor(1e-3))
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.g

class PreScaleNorm(nn.Module):
    def __init__(self, dim, fn, eps=1e-5):
        super().__init__()
        self.fn = fn
        self.g = nn.Parameter(torch.ones(1))
        self.eps = eps

    def forward(self, x, **kwargs):
        n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
        x = x / n * self.g
        return self.fn(x, **kwargs)

class PreLayerNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class Chunk(nn.Module):
    def __init__(self, chunks, fn, along_dim = -1):
        super().__init__()
        self.dim = along_dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x, **kwargs):
        if self.chunks == 1:
            return self.fn(x, **kwargs)
        chunks = x.chunk(self.chunks, dim = self.dim)
        return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
        super().__init__()
        activation = default(activation, nn.GELU)

        self.glu = glu
        self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
        self.act = activation()
        self.dropout = nn.Dropout(dropout)
        self.w2 = nn.Linear(dim * mult, dim)

    def forward(self, x, **kwargs):
        if not self.glu:
            x = self.w1(x)
            x = self.act(x)
        else:
            x, v = self.w1(x).chunk(2, dim=-1)
            x = self.act(x) * v

        x = self.dropout(x)
        x = self.w2(x)
        return x

class SelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        heads = 8,
        dim_head = 64,
        local_heads = 0,
        local_window_size = 256,
        nb_features = None,
        feature_redraw_interval = 1000,
        generalized_attention = False,
        kernel_fn = nn.ReLU(),
        dropout = 0.,
        no_projection = False,
        qkv_bias = False
    ):
        super().__init__()
        assert dim % heads == 0, 'dimension must be divisible by number of heads'
        dim_head = default(dim_head, dim // heads)
        inner_dim = dim_head * heads
        self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection)

        self.heads = heads
        self.global_heads = heads - local_heads
        self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None

        self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, output_attentions = False, **kwargs):
        b, n, _, h, gh = *x.shape, self.heads, self.global_heads

        cross_attend = exists(context)

        context = default(context, x)
        context_mask = default(context_mask, mask) if not cross_attend else context_mask

        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))

        attn_outs = []

        if not empty(q):
            if exists(context_mask):
                global_mask = context_mask[:, None, :, None]
                v.masked_fill_(~global_mask, 0.)

            if exists(pos_emb) and not cross_attend:
                q, k = apply_rotary_pos_emb(q, k, pos_emb)

            if output_attentions:
                out, attn_weights = self.fast_attention(q, k, v, output_attentions)
            else:
                out = self.fast_attention(q, k, v)
            attn_outs.append(out)

        if not empty(lq):
            assert not cross_attend, 'local attention is not compatible with cross attention'
            out = self.local_attn(lq, lk, lv, input_mask = mask)
            attn_outs.append(out)

        out = torch.cat(attn_outs, dim = 1)  # combine global and local attention outputs along the head dimension; with only global heads this is a no-op
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        if output_attentions:
            return self.dropout(out), attn_weights
        else:
            return self.dropout(out)

# positional embeddings

class AbsolutePositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        t = torch.arange(x.shape[1], device=x.device)
        return self.emb(t)

# rotary positional embedding helpers

def rotate_every_two(x):
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d j -> ... (d j)')

def apply_rotary_pos_emb(q, k, sinu_pos):
    sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2)
    sin, cos = sinu_pos.unbind(dim = -2)
    sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos))
    q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
    return q, k

# Gene2Vec positional embeddings

class Gene2VecPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len):
        super().__init__()
        gene2vec_weight = np.load('./data/gene2vec_16906.npy')
        gene2vec_weight = np.concatenate((gene2vec_weight, np.zeros((1, gene2vec_weight.shape[1]))), axis=0)
        gene2vec_weight = torch.from_numpy(gene2vec_weight)
        self.emb = nn.Embedding.from_pretrained(gene2vec_weight)

    def forward(self, x):
        t = torch.arange(x.shape[1], device=x.device)
        return self.emb(t)

# performer

class Performer(nn.Module):
    def __init__(
        self,
        dim,                              # dimension
        depth,                            # layers
        heads,                            # heads
        dim_head,                         # dim of head
        local_attn_heads = 0,             # num of local attention heads, (heads - local_attn_heads) is num of global performers
        local_window_size = 256,          # window size of local attention
        causal = False,                   # autoregressive or not
        ff_mult = 4,                      # dim of intermediate features after attention / dim of input features
        nb_features = None,               # number of random features used to approximate softmax attention; if not set, defaults to int(d * log(d)), where d is the dimension of each head
        feature_redraw_interval = 1000,   # how frequently to redraw the projection matrix, the more frequent, the slower the training
        reversible = False,               # reversible layers, from Reformer (save memory)
        ff_chunks = 1,                    # chunk feedforward layer, from Reformer
        generalized_attention = False,    # defaults to the softmax-kernel approximation; if True, uses generalized (kernelized) attention with kernel_fn applied to the projected queries/keys
        kernel_fn = nn.ReLU(),            # the kernel function to be used, if generalized attention is turned on, defaults to ReLU
        use_scalenorm = False,            # use ScaleNorm, from 'Transformers without Tears' paper, a substitute for LayerNorm, priority: scalenorm > rezero > layernorm
        use_rezero = False,               # use ReZero, from 'ReZero is all you need' paper, a substitute for LayerNorm, priority: scalenorm > rezero > layernorm
        ff_glu = False,                   # use GLU (Gated Linear Units) variant for feedforward
        ff_dropout = 0.,                  # feedforward dropout
        attn_dropout = 0.,                # post-attention dropout
        cross_attend = False,             # interleave cross-attention layers (decoder-style) after each self-attention layer
        no_projection = False,            # if True, skip the random projection and softmax the queries/keys directly
        auto_check_redraw = True,         # automatically check on each forward pass whether the projection matrices need to be redrawn
        qkv_bias = True,                  # whether the query/key/value linear projections carry a bias term
    ):
        super().__init__()
        layers = nn.ModuleList([])
        local_attn_heads = cast_tuple(local_attn_heads)
        local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads
        assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth'
        assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be less than the total number of heads'

        if use_scalenorm:
            wrapper_fn = partial(PreScaleNorm, dim)
        elif use_rezero:
            wrapper_fn = ReZero
        else:
            wrapper_fn = partial(PreLayerNorm, dim)

        for _, local_heads in zip(range(depth), local_attn_heads):
            layers.append(nn.ModuleList([
                wrapper_fn(SelfAttention(dim, causal = causal, heads = heads, dim_head = dim_head, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias)),
                wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
            ]))
            # if cross attention is not needed (encoder-only), move on to the next layer
            if not cross_attend:
                continue
            layers.append(nn.ModuleList([
                wrapper_fn(SelfAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection)),
                wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
            ]))

        execute_type = ReversibleSequence if reversible else SequentialSequence

        route_attn = ((True, False),) * depth * (2 if cross_attend else 1)  # e.g. for depth = 6: ((True, False),) * 6
        route_context = ((False, False), (True, False)) * depth
        attn_route_map = {'mask': route_attn, 'pos_emb': route_attn}
        context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {}
        self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map})

        # keeping track of when to redraw projections for all attention layers
        self.auto_check_redraw = auto_check_redraw
        self.feature_redraw_interval = feature_redraw_interval
        self.register_buffer('calls_since_last_redraw', torch.tensor(0))

    def fix_projection_matrices_(self):
        self.feature_redraw_interval = None

    def check_redraw_projections(self):
        if not self.training:
            return

        if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
            device = get_module_device(self)

            fast_attentions = find_modules(self, FastAttention)
            for fast_attention in fast_attentions:
                fast_attention.redraw_projection_matrix(device)

            self.calls_since_last_redraw.zero_()
            return

        self.calls_since_last_redraw += 1

    def forward(self, x, output_attentions = False, **kwargs):
        if self.auto_check_redraw:
            self.check_redraw_projections()
        return self.net(x, output_attentions = output_attentions, **kwargs)

class PerformerLM(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,                       # num of tokens
        max_seq_len,                      # max length of sequence
        dim,                              # dim of tokens
        depth,                            # layers
        heads,                            # num of heads
        dim_head = 64,                    # dim of heads
        local_attn_heads = 0,
        local_window_size = 256,
        causal = False,
        ff_mult = 4,
        nb_features = None,
        feature_redraw_interval = 1000,
        reversible = False,
        ff_chunks = 1,
        ff_glu = False,
        emb_dropout = 0.,
        ff_dropout = 0.,
        attn_dropout = 0.,
        generalized_attention = False,
        kernel_fn = nn.ReLU(),
        use_scalenorm = False,
        use_rezero = False,
        cross_attend = False,
        no_projection = False,
        tie_embed = False,                # True: multiply final embeddings with the token embedding weights for logits (weight tying, as in GPT decoders); False: use a separate nn.Linear(dim, num_tokens) output head
        g2v_position_emb = True,          # True: use the pretrained Gene2Vec positional embedding; False: no positional embedding (zeros)
        auto_check_redraw = True,
        qkv_bias = False
    ):
        super().__init__()
        local_attn_heads = cast_tuple(local_attn_heads)

        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(num_tokens, dim)

        if g2v_position_emb:
            self.pos_emb = Gene2VecPositionalEmbedding(dim, max_seq_len)
            self.layer_pos_emb = Always(None)
        else:
            self.pos_emb = torch.zeros_like
            self.layer_pos_emb = Always(None)

        self.dropout = nn.Dropout(emb_dropout)

        self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias)
        self.norm = nn.LayerNorm(dim)
        self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None

    def check_redraw_projections(self):
        self.performer.check_redraw_projections()

    def fix_projection_matrices_(self):
        self.performer.fix_projection_matrices_()

    def forward(self, x, return_encodings = False, output_attentions = False, **kwargs):
        b, n, device = *x.shape, x.device
        assert n <= self.max_seq_len, f'sequence length {n} must not exceed the max sequence length {self.max_seq_len}'

        #pdb.set_trace()
        # token and positional embedding
        x = self.token_emb(x)
        if output_attentions:
            x.requires_grad_()  # used for attn_map output
        x += self.pos_emb(x)
        x = self.dropout(x)

        # performer layers
        layer_pos_emb = self.layer_pos_emb(x)

        if output_attentions:
            x, attn_weights = self.performer(x, pos_emb = layer_pos_emb, output_attentions = output_attentions, **kwargs)
            # norm and to logits
            x = self.norm(x)
            if return_encodings:
                return x, attn_weights

            if exists(self.to_out):
                return self.to_out(x), attn_weights

            return (x @ self.token_emb.weight.t()), attn_weights
        else:
            x = self.performer(x, pos_emb = layer_pos_emb, output_attentions = output_attentions, **kwargs)

            # norm and to logits
            x = self.norm(x)
            if return_encodings:
                return x

            if exists(self.to_out):
                x = self.to_out(x)
                return x

            return x @ self.token_emb.weight.t()
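
A minimal sketch of how PerformerLM is called, assuming a hypothetical toy configuration: scBERT itself ties num_tokens and max_seq_len to the gene vocabulary and loads ./data/gene2vec_16906.npy, which is skipped here by setting g2v_position_emb = False.

import torch
from performer_pytorch import PerformerLM

# hypothetical toy configuration, for illustration only
model = PerformerLM(
    num_tokens = 7,
    max_seq_len = 1024,
    dim = 128,
    depth = 2,
    heads = 4,
    g2v_position_emb = False,   # do not load the Gene2Vec table in this sketch
)

x = torch.randint(0, 7, (2, 1024))

logits = model(x)                                # (2, 1024, 7) token logits
encodings = model(x, return_encodings = True)    # (2, 1024, 128) hidden states
# on a CUDA device, approximate attention maps (averaged over heads and layers)
# can also be requested:
# logits, attn = model(x, output_attentions = True)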
performer_pytorch/reversible.py
ADDED
@@ -0,0 +1,167 @@
import torch
import torch.nn as nn
from operator import itemgetter
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states

# for routing arguments into the functions of the reversible layer
def route_args(router, args, depth):
    routed_args = [(dict(), dict()) for _ in range(depth)]
    matched_keys = [key for key in args.keys() if key in router]

    for key in matched_keys:
        val = args[key]
        for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
            new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
            routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
    return routed_args

# following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
class Deterministic(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.cpu_state = None
        self.cuda_in_fwd = None
        self.gpu_devices = None
        self.gpu_states = None

    def record_rng(self, *args):
        self.cpu_state = torch.get_rng_state()
        if torch.cuda._initialized:
            self.cuda_in_fwd = True
            self.gpu_devices, self.gpu_states = get_device_states(*args)

    def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
        if record_rng:
            self.record_rng(*args)

        if not set_rng:
            return self.net(*args, **kwargs)

        rng_devices = []
        if self.cuda_in_fwd:
            rng_devices = self.gpu_devices

        with torch.random.fork_rng(devices=rng_devices, enabled=True):
            torch.set_rng_state(self.cpu_state)
            if self.cuda_in_fwd:
                set_device_states(self.gpu_devices, self.gpu_states)
            return self.net(*args, **kwargs)

# heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
# once multi-GPU is confirmed working, refactor and send PR back to source
class ReversibleBlock(nn.Module):
    def __init__(self, f, g):
        super().__init__()
        self.f = Deterministic(f)
        self.g = Deterministic(g)

    def forward(self, x, f_args = {}, g_args = {}):
        x1, x2 = torch.chunk(x, 2, dim=2)
        y1, y2 = None, None

        with torch.no_grad():
            y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
            y2 = x2 + self.g(y1, record_rng=self.training, **g_args)

        return torch.cat([y1, y2], dim=2)

    def backward_pass(self, y, dy, f_args = {}, g_args = {}):
        y1, y2 = torch.chunk(y, 2, dim=2)
        del y

        dy1, dy2 = torch.chunk(dy, 2, dim=2)
        del dy

        with torch.enable_grad():
            y1.requires_grad = True
            gy1 = self.g(y1, set_rng=True, **g_args)
            torch.autograd.backward(gy1, dy2)

        with torch.no_grad():
            x2 = y2 - gy1
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f(x2, set_rng=True, **f_args)
            torch.autograd.backward(fx2, dx1, retain_graph=True)

        with torch.no_grad():
            x1 = y1 - fx2
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=2)
            dx = torch.cat([dx1, dx2], dim=2)

        return x, dx

class _ReversibleFunction(Function):
    @staticmethod
    def forward(ctx, x, blocks, args):
        ctx.args = args
        for block, kwarg in zip(blocks, args):
            x = block(x, **kwarg)
        ctx.y = x.detach()
        ctx.blocks = blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        args = ctx.args
        for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
            y, dy = block.backward_pass(y, dy, **kwargs)
        return dy, None, None

class SequentialSequence(nn.Module):
    def __init__(self, layers, args_route = {}):
        super().__init__()
        assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
        self.layers = layers
        self.args_route = args_route

    def forward(self, x, output_attentions = False, **kwargs):
        args = route_args(self.args_route, kwargs, len(self.layers))
        layers_and_args = list(zip(self.layers, args))

        if output_attentions:
            attn_weights = []
        for (f, g), (f_args, g_args) in layers_and_args:
            if output_attentions:
                x = x + f(x, output_attentions = output_attentions, **f_args)[0]
                attn_weights.append(f(x, output_attentions = output_attentions, **f_args)[1].unsqueeze(0))
            else:
                x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        if output_attentions:
            attn_weights = torch.transpose(torch.cat(attn_weights, dim=0), 0, 1)  # the final dim is (batch, layer, head, len, len)
            attn_weights = torch.mean(attn_weights, dim=1)  # the dim is (batch, head, len, len)
            return x, attn_weights
        else:
            return x

class ReversibleSequence(nn.Module):
    def __init__(self, blocks, args_route = {}):
        super().__init__()
        self.args_route = args_route
        self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])

    def forward(self, x, **kwargs):
        x = torch.cat([x, x], dim=-1)

        blocks = self.blocks
        args = route_args(self.args_route, kwargs, len(blocks))
        args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))

        out = _ReversibleFunction.apply(x, blocks, args)
        return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)
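
A small, self-contained illustration of the routing convention used by SequentialSequence and ReversibleSequence above. The router below is hypothetical; in Performer the routes are built from route_attn and route_context.

import torch
from performer_pytorch.reversible import route_args

# one (f_route, g_route) pair per layer: True forwards the kwarg to that half
# of the block (f is the attention sublayer, g the feedforward sublayer)
router = {'mask': ((True, False), (True, False))}
kwargs = {'mask': torch.ones(1, 4, dtype = torch.bool), 'unused': 0}

routed = route_args(router, kwargs, depth = 2)
# routed[0] -> ({'mask': tensor([[True, True, True, True]])}, {})
# keys absent from the router ('unused') are silently dropped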