account18hackathon commited on
Commit
26a9f0f
1 Parent(s): 4e40454

Upload 4 files

Browse files
performer_pytorch/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from performer_pytorch.performer_pytorch import PerformerLM, Performer, FastAttention, SelfAttention
performer_pytorch/autoregressive_wrapper.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from torch.nn.utils.rnn import pad_sequence
6
+
7
+ import pdb
8
+
9
+
10
+ def exists(val):
11
+ return val is not None
12
+
13
+ def top_p(logits, thres = 0.9):
14
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
15
+ cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
16
+
17
+ sorted_indices_to_remove = cum_probs > (1 - thres)
18
+ sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
19
+ sorted_indices_to_remove[:, 0] = 0
20
+
21
+ sorted_logits[sorted_indices_to_remove] = float('-inf')
22
+ return sorted_logits.scatter(1, sorted_indices, sorted_logits)
23
+
24
+ def top_k(logits, thres = 0.9):
25
+ k = int((1 - thres) * logits.shape[-1])
26
+ val, ind = torch.topk(logits, k)
27
+ probs = torch.full_like(logits, float('-inf'))
28
+ probs.scatter_(1, ind, val)
29
+ return probs
30
+
31
+ def repetition_penalty_fn(logits, ctx, theta=1.2):
32
+ w = torch.ones(logits.shape[-1], dtype=torch.float, device=logits.device)
33
+ for i in torch.unique(ctx):
34
+ w[i] = theta
35
+ return logits/w
36
+
37
+ class AutoregressiveWrapper(nn.Module):
38
+ def __init__(self, net, ignore_index = 0, pad_value = 0):
39
+ super().__init__()
40
+ self.pad_value = pad_value
41
+ self.ignore_index = ignore_index
42
+
43
+ self.net = net
44
+ self.max_seq_len = net.max_seq_len
45
+
46
+ @torch.no_grad()
47
+ def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, repetition_penalty=1.0, repetition_penalty_ctx=32, **kwargs):
48
+ was_training = self.net.training
49
+ num_dims = len(start_tokens.shape)
50
+
51
+ if num_dims == 1:
52
+ start_tokens = start_tokens[None, :]
53
+
54
+ b, t = start_tokens.shape
55
+
56
+ self.net.eval()
57
+ out = start_tokens
58
+ input_mask = kwargs.pop('mask', None)
59
+
60
+ if input_mask is None:
61
+ input_mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)
62
+
63
+ # in case of conditional generation, if enc_mask is not provided use the correct context_mask
64
+ context_mask = kwargs.pop('context_mask', None)
65
+
66
+ if 'context' in kwargs and not exists(context_mask):
67
+ context = kwargs['context']
68
+ context_mask = torch.full(context.shape[:2], True, dtype=torch.bool, device=out.device)
69
+
70
+ kwargs.update(context_mask = context_mask)
71
+
72
+ for _ in range(seq_len):
73
+ x = out[:, -self.max_seq_len:]
74
+ input_mask = input_mask[:, -self.max_seq_len:]
75
+ logits = self.net(x, mask=input_mask, **kwargs)[:, -1, :]
76
+ if repetition_penalty > 1.0:
77
+ logits = repetition_penalty_fn(logits, out[-repetition_penalty_ctx:], theta=repetition_penalty)
78
+ filtered_logits = filter_logits_fn(logits, thres = filter_thres)
79
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
80
+ sample = torch.multinomial(probs, 1)
81
+
82
+ out = torch.cat((out, sample), dim=-1)
83
+ input_mask = F.pad(input_mask, (0, 1), value=True)
84
+
85
+ if eos_token is not None and (sample == eos_token).all():
86
+ break
87
+
88
+ out = out[:, t:]
89
+
90
+ if num_dims == 1:
91
+ out = out.squeeze(0)
92
+
93
+ self.net.train(was_training)
94
+ return out
95
+
96
+ def forward(self, x, **kwargs):
97
+ xi = x[:, :-1]
98
+ xo = x[:, 1:]
99
+
100
+ # help auto-solve an area of confusion around input masks in auto-regressive
101
+ # if user supplies a mask that is only off by one from the source sequence, resolve it for them
102
+ mask = kwargs.pop('mask', None)
103
+ if mask is not None and mask.shape[1] == x.shape[1]:
104
+ mask = mask[:, :-1]
105
+ kwargs.update(mask = mask)
106
+
107
+ out = self.net(xi, **kwargs)
108
+
109
+ loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
110
+
111
+ #pdb.set_trace()
112
+
113
+ return loss
performer_pytorch/performer_pytorch.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from torch.cuda.amp import autocast
7
+ from einops import rearrange, repeat
8
+
9
+ from functools import partial
10
+ from contextlib import contextmanager
11
+
12
+ from local_attention import LocalAttention
13
+ from performer_pytorch.reversible import ReversibleSequence, SequentialSequence
14
+
15
+ import pdb
16
+
17
+ try:
18
+ from apex import amp
19
+ APEX_AVAILABLE = True
20
+ except:
21
+ APEX_AVAILABLE = False
22
+
23
+ # helpers
24
+
25
+ def exists(val):
26
+ return val is not None
27
+
28
+ def empty(tensor):
29
+ return tensor.numel() == 0
30
+
31
+ def default(val, d):
32
+ return val if exists(val) else d
33
+
34
+ @contextmanager
35
+ def null_context():
36
+ yield
37
+
38
+ def cast_tuple(val):
39
+ return (val,) if not isinstance(val, tuple) else val
40
+
41
+ # def get_module_device(module):
42
+ # return next(module.parameters).device
43
+
44
+ def get_module_device(module):
45
+ try:
46
+ return next(module.parameters()).device
47
+ except StopIteration:
48
+ # For nn.DataParallel compatibility in PyTorch 1.5
49
+ def find_tensor_attributes(module):
50
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
51
+ return tuples
52
+ gen = module._named_members(get_members_fn=find_tensor_attributes)
53
+ first_tuple = next(gen)
54
+ return first_tuple[1].device
55
+
56
+ def find_modules(nn_module, type):
57
+ return [module for module in nn_module.modules() if isinstance(module, type)]
58
+
59
+ class Always(nn.Module):
60
+ def __init__(self, val):
61
+ super().__init__()
62
+ self.val = val
63
+
64
+ def forward(self, *args, **kwargs):
65
+ return self.val
66
+
67
+ # kernel functions
68
+
69
+ # transcribed from jax to pytorch from
70
+ # https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py
71
+
72
+ def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
73
+ b, h, *_ = data.shape
74
+
75
+ data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
76
+
77
+ ratio = (projection_matrix.shape[0] ** -0.5)
78
+
79
+ projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
80
+ projection = projection.type_as(data)
81
+
82
+ data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
83
+
84
+ diag_data = data ** 2
85
+ diag_data = torch.sum(diag_data, dim=-1)
86
+ diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
87
+ diag_data = diag_data.unsqueeze(dim=-1)
88
+
89
+ if is_query:
90
+ data_dash = ratio * (
91
+ torch.exp(data_dash - diag_data -
92
+ torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
93
+ else:
94
+ data_dash = ratio * (
95
+ torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)
96
+
97
+ return data_dash.type_as(data)
98
+
99
+ def generalized_kernel(data, *, projection_matrix, kernel_fn = nn.ReLU(), kernel_epsilon = 0.001, normalize_data = True, device = None):
100
+ b, h, *_ = data.shape
101
+
102
+ data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
103
+
104
+ if projection_matrix is None:
105
+ return kernel_fn(data_normalizer * data) + kernel_epsilon
106
+
107
+ projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
108
+ projection = projection.type_as(data)
109
+
110
+ data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
111
+
112
+ data_prime = kernel_fn(data_dash) + kernel_epsilon
113
+ return data_prime.type_as(data)
114
+
115
+ def orthogonal_matrix_chunk(cols, device = None):
116
+ unstructured_block = torch.randn((cols, cols), device = device)
117
+ q, r = torch.qr(unstructured_block.cpu(), some = True)
118
+ q, r = map(lambda t: t.to(device), (q, r))
119
+ return q.t()
120
+
121
+ def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, device = None):
122
+ nb_full_blocks = int(nb_rows / nb_columns)
123
+
124
+ block_list = []
125
+
126
+ for _ in range(nb_full_blocks):
127
+ q = orthogonal_matrix_chunk(nb_columns, device = device)
128
+ block_list.append(q)
129
+
130
+ remaining_rows = nb_rows - nb_full_blocks * nb_columns
131
+ if remaining_rows > 0:
132
+ q = orthogonal_matrix_chunk(nb_columns, device = device)
133
+ block_list.append(q[:remaining_rows])
134
+
135
+ final_matrix = torch.cat(block_list)
136
+
137
+ if scaling == 0:
138
+ multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
139
+ elif scaling == 1:
140
+ multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
141
+ else:
142
+ raise ValueError(f'Invalid scaling {scaling}')
143
+
144
+ return torch.diag(multiplier) @ final_matrix
145
+
146
+ # linear attention classes with softmax kernel
147
+
148
+ # non-causal linear attention
149
+ def linear_attention(q, k, v):
150
+ k_cumsum = k.sum(dim = -2)
151
+ D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
152
+ context = torch.einsum('...nd,...ne->...de', k, v)
153
+ out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
154
+ return out
155
+
156
+ # efficient causal linear attention, created by EPFL
157
+ # TODO: rewrite EPFL's CUDA kernel to do mixed precision and remove half to float conversion and back
158
+ def causal_linear_attention(q, k, v, eps = 1e-6):
159
+ from fast_transformers.causal_product import CausalDotProduct
160
+ autocast_enabled = torch.is_autocast_enabled()
161
+ is_half = isinstance(q, torch.cuda.HalfTensor)
162
+ assert not is_half or APEX_AVAILABLE, 'half tensors can only be used if nvidia apex is available'
163
+ cuda_context = null_context if not autocast_enabled else partial(autocast, enabled = False)
164
+
165
+ causal_dot_product_fn = amp.float_function(CausalDotProduct.apply) if is_half else CausalDotProduct.apply
166
+
167
+ k_cumsum = k.cumsum(dim=-2) + eps
168
+ D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q))
169
+
170
+ with cuda_context():
171
+ if autocast_enabled:
172
+ q, k, v = map(lambda t: t.float(), (q, k, v))
173
+
174
+ out = causal_dot_product_fn(q, k, v)
175
+
176
+ out = torch.einsum('...nd,...n->...nd', out, D_inv)
177
+ return out
178
+
179
+ # inefficient causal linear attention, without cuda code, for reader's reference
180
+ # not being used
181
+ def causal_linear_attention_noncuda(q, k, v, chunk_size = 128):
182
+ last_k_cumsum = 0
183
+ last_context_cumsum = 0
184
+ outs = []
185
+
186
+ for q, k, v in zip(*map(lambda t: t.chunk(chunk_size, dim = -2), (q, k, v))):
187
+ k_cumsum = last_k_cumsum + k.cumsum(dim=-2)
188
+
189
+ D_inv = 1. / torch.einsum('...nd,...nd->...n', q, k_cumsum.type_as(q))
190
+ context = torch.einsum('...nd,...ne->...nde', k, v)
191
+ context_cumsum = last_context_cumsum + context.cumsum(dim=-3)
192
+ out = torch.einsum('...nde,...nd,...n->...ne', context_cumsum, q, D_inv)
193
+
194
+ last_k_cumsum = k_cumsum[:, :, -1:]
195
+ last_context_cumsum = context_cumsum[:, :, -1:]
196
+ outs.append(out)
197
+
198
+ return torch.cat(outs, dim = -2)
199
+
200
+ def norm_tensor(tensor, dim=-1):
201
+ return tensor / tensor.sum(dim=dim).unsqueeze(dim)
202
+
203
+ class FastAttention(nn.Module):
204
+ def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), no_projection = False):
205
+ super().__init__()
206
+ nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
207
+
208
+ self.dim_heads = dim_heads
209
+ self.nb_features = nb_features
210
+ self.ortho_scaling = ortho_scaling
211
+
212
+ self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling)
213
+ projection_matrix = self.create_projection()
214
+ self.register_buffer('projection_matrix', projection_matrix)
215
+
216
+ self.generalized_attention = generalized_attention
217
+ self.kernel_fn = kernel_fn
218
+
219
+ # if this is turned on, no projection will be used
220
+ # queries and keys will be softmax-ed as in the original efficient attention paper
221
+ self.no_projection = no_projection
222
+
223
+ self.causal = causal
224
+ if causal:
225
+ try:
226
+ import fast_transformers.causal_product.causal_product_cuda
227
+ self.causal_linear_fn = partial(causal_linear_attention)
228
+ except ImportError:
229
+ print('unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version')
230
+ self.causal_linear_fn = causal_linear_attention_noncuda
231
+
232
+ @torch.no_grad()
233
+ def redraw_projection_matrix(self, device):
234
+ projections = self.create_projection(device = device)
235
+ self.projection_matrix.copy_(projections)
236
+ del projections
237
+
238
+ def forward(self, q, k, v, output_attentions = False):
239
+ device = q.device
240
+ # inds = [8060, 8064, 6243, 8575, 10342, 10913, 9366, 993, 7796, 5210, 5212, 5504, 6851, 6559, 5508, 13107, 13820]
241
+ if self.no_projection:
242
+ q = q.softmax(dim = -1)
243
+ k = torch.exp(k) if self.causal else k.softmax(dim = -2)
244
+
245
+ elif self.generalized_attention:
246
+ create_kernel = partial(generalized_kernel, kernel_fn = self.kernel_fn, projection_matrix = self.projection_matrix, device = device)
247
+ q, k = map(create_kernel, (q, k))
248
+
249
+ else:
250
+ create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
251
+ q = create_kernel(q, is_query = True)
252
+ k = create_kernel(k, is_query = False)
253
+
254
+ attn_fn = linear_attention if not self.causal else self.causal_linear_fn
255
+ out = attn_fn(q, k, v)
256
+ if output_attentions:
257
+ v_diag = torch.eye(v.shape[-2]).to(device)
258
+ v_diag = v_diag.unsqueeze(0).unsqueeze(0).repeat(v.shape[0],v.shape[1],1,1)
259
+ # attn_weights = torch.zeros(1, 1, len(inds), len(inds)).to(device).to(torch.float16)
260
+ # attn_weights = torch.zeros(1, q.shape[1], len(inds), len(inds)).to(device).to(torch.float16)
261
+ attn_weights = torch.zeros(1, 1, q.shape[2], q.shape[2]).to(device).to(torch.float16)
262
+ for head_dim in range(q.shape[1]):
263
+ # attn_weights[0, head_dim] = torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16)))[0, inds][:, inds]
264
+ attn_weights += torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16)))
265
+ # attn_weights += norm_tensor(torch.abs(attn_fn(q[:,head_dim].to(torch.float16), k[:,head_dim].to(torch.float16), v_diag[:,head_dim].to(torch.float16))), dim=-1)
266
+ attn_weights /= q.shape[1]
267
+ return out, attn_weights
268
+ else:
269
+ return out
270
+
271
+ # classes
272
+
273
+ class ReZero(nn.Module):
274
+ def __init__(self, fn):
275
+ super().__init__()
276
+ self.g = nn.Parameter(torch.tensor(1e-3))
277
+ self.fn = fn
278
+
279
+ def forward(self, x, **kwargs):
280
+ return self.fn(x, **kwargs) * self.g
281
+
282
+ class PreScaleNorm(nn.Module):
283
+ def __init__(self, dim, fn, eps=1e-5):
284
+ super().__init__()
285
+ self.fn = fn
286
+ self.g = nn.Parameter(torch.ones(1))
287
+ self.eps = eps
288
+
289
+ def forward(self, x, **kwargs):
290
+ n = torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
291
+ x = x / n * self.g
292
+ return self.fn(x, **kwargs)
293
+
294
+ class PreLayerNorm(nn.Module):
295
+ def __init__(self, dim, fn):
296
+ super().__init__()
297
+ self.norm = nn.LayerNorm(dim)
298
+ self.fn = fn
299
+ def forward(self, x, **kwargs):
300
+ return self.fn(self.norm(x), **kwargs)
301
+
302
+ class Chunk(nn.Module):
303
+ def __init__(self, chunks, fn, along_dim = -1):
304
+ super().__init__()
305
+ self.dim = along_dim
306
+ self.chunks = chunks
307
+ self.fn = fn
308
+
309
+ def forward(self, x, **kwargs):
310
+ if self.chunks == 1:
311
+ return self.fn(x, **kwargs)
312
+ chunks = x.chunk(self.chunks, dim = self.dim)
313
+ return torch.cat([self.fn(c, **kwargs) for c in chunks], dim = self.dim)
314
+
315
+ class FeedForward(nn.Module):
316
+ def __init__(self, dim, mult = 4, dropout = 0., activation = None, glu = False):
317
+ super().__init__()
318
+ activation = default(activation, nn.GELU)
319
+
320
+ self.glu = glu
321
+ self.w1 = nn.Linear(dim, dim * mult * (2 if glu else 1))
322
+ self.act = activation()
323
+ self.dropout = nn.Dropout(dropout)
324
+ self.w2 = nn.Linear(dim * mult, dim)
325
+
326
+ def forward(self, x, **kwargs):
327
+ if not self.glu:
328
+ x = self.w1(x)
329
+ x = self.act(x)
330
+ else:
331
+ x, v = self.w1(x).chunk(2, dim=-1)
332
+ x = self.act(x) * v
333
+
334
+ x = self.dropout(x)
335
+ x = self.w2(x)
336
+ return x
337
+
338
+ class SelfAttention(nn.Module):
339
+ def __init__(
340
+ self,
341
+ dim,
342
+ causal = False,
343
+ heads = 8,
344
+ dim_head = 64,
345
+ local_heads = 0,
346
+ local_window_size = 256,
347
+ nb_features = None,
348
+ feature_redraw_interval = 1000,
349
+ generalized_attention = False,
350
+ kernel_fn = nn.ReLU(),
351
+ dropout = 0.,
352
+ no_projection = False,
353
+ qkv_bias = False
354
+ ):
355
+ super().__init__()
356
+ assert dim % heads == 0, 'dimension must be divisible by number of heads'
357
+ dim_head = default(dim_head, dim // heads)
358
+ inner_dim = dim_head * heads
359
+ self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, no_projection = no_projection)
360
+
361
+ self.heads = heads
362
+ self.global_heads = heads - local_heads
363
+ self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
364
+
365
+ self.to_q = nn.Linear(dim, inner_dim, bias = qkv_bias)
366
+ self.to_k = nn.Linear(dim, inner_dim, bias = qkv_bias)
367
+ self.to_v = nn.Linear(dim, inner_dim, bias = qkv_bias)
368
+ self.to_out = nn.Linear(inner_dim, dim)
369
+ self.dropout = nn.Dropout(dropout)
370
+
371
+ def forward(self, x, pos_emb = None, context = None, mask = None, context_mask = None, output_attentions = False, **kwargs):
372
+ b, n, _, h, gh = *x.shape, self.heads, self.global_heads
373
+
374
+ cross_attend = exists(context)
375
+
376
+ context = default(context, x)
377
+ context_mask = default(context_mask, mask) if not cross_attend else context_mask
378
+
379
+ q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
380
+
381
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
382
+ (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
383
+
384
+ attn_outs = []
385
+
386
+ if not empty(q):
387
+ if exists(context_mask):
388
+ global_mask = context_mask[:, None, :, None]
389
+ v.masked_fill_(~global_mask, 0.)
390
+
391
+ if exists(pos_emb) and not cross_attend:
392
+ q, k, = apply_rotary_pos_emb(q, k, pos_emb)
393
+
394
+ if output_attentions:
395
+ out, attn_weights = self.fast_attention(q, k, v, output_attentions)
396
+ else:
397
+ out = self.fast_attention(q, k, v)
398
+ attn_outs.append(out)
399
+
400
+ if not empty(lq):
401
+ assert not cross_attend, 'local attention is not compatible with cross attention'
402
+ out = self.local_attn(lq, lk, lv, input_mask = mask)
403
+ attn_outs.append(out)
404
+
405
+ out = torch.cat(attn_outs, dim = 1) # combine attn_out and cross_attn_out, here we have only attn_out, that means this line does nothing
406
+ out = rearrange(out, 'b h n d -> b n (h d)')
407
+ out = self.to_out(out)
408
+ if output_attentions:
409
+ return self.dropout(out), attn_weights
410
+ else:
411
+ return self.dropout(out)
412
+
413
+ # positional embeddings
414
+
415
+ class AbsolutePositionalEmbedding(nn.Module):
416
+ def __init__(self, dim, max_seq_len):
417
+ super().__init__()
418
+ self.emb = nn.Embedding(max_seq_len, dim)
419
+
420
+ def forward(self, x):
421
+ t = torch.arange(x.shape[1], device=x.device)
422
+ return self.emb(t)
423
+
424
+ # rotary positional embedding helpers
425
+
426
+ def rotate_every_two(x):
427
+ x = rearrange(x, '... (d j) -> ... d j', j = 2)
428
+ x1, x2 = x.unbind(dim = -1)
429
+ x = torch.stack((-x2, x1), dim = -1)
430
+ return rearrange(x, '... d j -> ... (d j)')
431
+
432
+ def apply_rotary_pos_emb(q, k, sinu_pos):
433
+ sinu_pos = rearrange(sinu_pos, '() n (j d) -> n j d', j = 2)
434
+ sin, cos = sinu_pos.unbind(dim = -2)
435
+ sin, cos = map(lambda t: repeat(t, 'b n -> b (n j)', j = 2), (sin, cos))
436
+ q, k = map(lambda t: (t * cos) + (rotate_every_two(t) * sin), (q, k))
437
+ return q, k
438
+
439
+ # sinusoidal positional embeddings
440
+
441
+ class Gene2VecPositionalEmbedding(nn.Module):
442
+ def __init__(self, dim, max_seq_len):
443
+ super().__init__()
444
+ gene2vec_weight = np.load('./data/gene2vec_16906.npy')
445
+ gene2vec_weight = np.concatenate((gene2vec_weight, np.zeros((1, gene2vec_weight.shape[1]))), axis=0)
446
+ gene2vec_weight = torch.from_numpy(gene2vec_weight)
447
+ self.emb = nn.Embedding.from_pretrained(gene2vec_weight)
448
+
449
+ def forward(self, x):
450
+ t = torch.arange(x.shape[1], device=x.device)
451
+ return self.emb(t)
452
+
453
+ # performer
454
+
455
+ class Performer(nn.Module):
456
+ def __init__(
457
+ self,
458
+ dim, # dimension
459
+ depth, # layers
460
+ heads, # heads
461
+ dim_head, # dim of head
462
+ local_attn_heads = 0, # num of local attention heads, (heads - local_attn_heads) is num of global performers
463
+ local_window_size = 256, # window size of local attention
464
+ causal = False, # autoregressive or not
465
+ ff_mult = 4, # dim of intermediate features after attention / dim of input features
466
+ nb_features = None, # number of random features, if not set, will default to (d * log(d)), where d is the dimension of each head ?? what is random feature ??
467
+ feature_redraw_interval = 1000, # how frequently to redraw the projection matrix, the more frequent, the slower the training
468
+ reversible = False, # reversible layers, from Reformer (save memory)
469
+ ff_chunks = 1, # chunk feedforward layer, from Reformer
470
+ generalized_attention = False, # defaults to softmax approximation, but can be set to True for generalized attention ?? what is generalized attention ??
471
+ kernel_fn = nn.ReLU(), # the kernel function to be used, if generalized attention is turned on, defaults to Relu
472
+ use_scalenorm = False, # use scale norm, from 'Transformers without Tears' paper, a substitute for LayerNorm, priority: scalenorm.rezero.layernorm
473
+ use_rezero = False, # use Rezero or not, from 'Rezero is all you need' paper, a substitute for LayerNorm, priority: scalenorm.rezero.layernorm
474
+ ff_glu = False, # use GLU (Gated Linear Units) variant for feedforward
475
+ ff_dropout = 0., # feedforward dropout
476
+ attn_dropout = 0., # post-attention dropout
477
+ cross_attend = False, # ??
478
+ no_projection = False, # ??
479
+ auto_check_redraw = True, # ??
480
+ qkv_bias = True, # ??
481
+ ):
482
+ super().__init__()
483
+ layers = nn.ModuleList([])
484
+ local_attn_heads = cast_tuple(local_attn_heads)
485
+ local_attn_heads = local_attn_heads * depth if len(local_attn_heads) == 1 else local_attn_heads
486
+ assert len(local_attn_heads) == depth, 'tuple specifying number of local attention heads per depth must be equal to the total depth'
487
+ assert all(map(lambda n: n >= 0 and n <= heads, local_attn_heads)), 'local attention head value must be less than the total number of heads'
488
+
489
+ if use_scalenorm:
490
+ wrapper_fn = partial(PreScaleNorm, dim)
491
+ elif use_rezero:
492
+ wrapper_fn = ReZero
493
+ else:
494
+ wrapper_fn = partial(PreLayerNorm, dim)
495
+
496
+ for _, local_heads in zip(range(depth), local_attn_heads):
497
+ layers.append(nn.ModuleList([
498
+ wrapper_fn(SelfAttention(dim, causal = causal, heads = heads, dim_head = dim_head, local_heads = local_heads, local_window_size = local_window_size, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection, qkv_bias = qkv_bias)),
499
+ wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
500
+ ]))
501
+ # if no need cross_attend(decoder), begin next cycle
502
+ if not cross_attend:
503
+ continue
504
+ layers.append(nn.ModuleList([
505
+ wrapper_fn(SelfAttention(dim, heads = heads, dim_head = dim_head, nb_features = nb_features, generalized_attention = generalized_attention, kernel_fn = kernel_fn, dropout = attn_dropout, no_projection = no_projection)),
506
+ wrapper_fn(Chunk(ff_chunks, FeedForward(dim, mult = ff_mult, dropout = ff_dropout, glu = ff_glu), along_dim = 1))
507
+ ]))
508
+
509
+ execute_type = ReversibleSequence if reversible else SequentialSequence
510
+
511
+ route_attn = ((True, False),) * depth * (2 if cross_attend else 1) # ((True, False), (True, False), (True, False), (True, False), (True, False), (True, False))
512
+ route_context = ((False, False), (True, False)) * depth
513
+ attn_route_map = {'mask': route_attn, 'pos_emb': route_attn}
514
+ context_route_map = {'context': route_context, 'context_mask': route_context} if cross_attend else {}
515
+ self.net = execute_type(layers, args_route = {**attn_route_map, **context_route_map})
516
+
517
+ # keeping track of when to redraw projections for all attention layers
518
+ self.auto_check_redraw = auto_check_redraw
519
+ self.feature_redraw_interval = feature_redraw_interval
520
+ self.register_buffer('calls_since_last_redraw', torch.tensor(0))
521
+
522
+ def fix_projection_matrices_(self):
523
+ self.feature_redraw_interval = None
524
+
525
+ def check_redraw_projections(self):
526
+ if not self.training:
527
+ return
528
+
529
+ if exists(self.feature_redraw_interval) and self.calls_since_last_redraw >= self.feature_redraw_interval:
530
+ device = get_module_device(self)
531
+
532
+ fast_attentions = find_modules(self, FastAttention)
533
+ for fast_attention in fast_attentions:
534
+ fast_attention.redraw_projection_matrix(device)
535
+
536
+ self.calls_since_last_redraw.zero_()
537
+ return
538
+
539
+ self.calls_since_last_redraw += 1
540
+
541
+ def forward(self, x, output_attentions = False, **kwargs):
542
+ if self.auto_check_redraw:
543
+ self.check_redraw_projections()
544
+ return self.net(x, output_attentions = output_attentions, **kwargs)
545
+
546
+ class PerformerLM(nn.Module):
547
+ def __init__(
548
+ self,
549
+ *,
550
+ num_tokens, # num of tokens
551
+ max_seq_len, # max length of sequence
552
+ dim, # dim of tokens
553
+ depth, # layers
554
+ heads, # num of heads
555
+ dim_head = 64, # dim of heads
556
+ local_attn_heads = 0,
557
+ local_window_size = 256,
558
+ causal = False,
559
+ ff_mult = 4,
560
+ nb_features = None,
561
+ feature_redraw_interval = 1000,
562
+ reversible = False,
563
+ ff_chunks = 1,
564
+ ff_glu = False,
565
+ emb_dropout = 0.,
566
+ ff_dropout = 0.,
567
+ attn_dropout = 0.,
568
+ generalized_attention = False,
569
+ kernel_fn = nn.ReLU(),
570
+ use_scalenorm = False,
571
+ use_rezero = False,
572
+ cross_attend = False,
573
+ no_projection = False,
574
+ tie_embed = False, # False: output is num of tokens, True: output is dim of tokens //multiply final embeddings with token weights for logits, like gpt decoder//
575
+ g2v_position_emb = True, # priority: gene2vec, no embedding
576
+ auto_check_redraw = True,
577
+ qkv_bias = False
578
+ ):
579
+ super().__init__()
580
+ local_attn_heads = cast_tuple(local_attn_heads)
581
+
582
+ self.max_seq_len = max_seq_len
583
+ self.token_emb = nn.Embedding(num_tokens, dim)
584
+
585
+ if g2v_position_emb:
586
+ self.pos_emb = Gene2VecPositionalEmbedding(dim, max_seq_len)
587
+ self.layer_pos_emb = Always(None)
588
+ else:
589
+ self.pos_emb = torch.zeros_like
590
+ self.layer_pos_emb = Always(None)
591
+
592
+ self.dropout = nn.Dropout(emb_dropout)
593
+
594
+ self.performer = Performer(dim, depth, heads, dim_head, local_attn_heads, local_window_size, causal, ff_mult, nb_features, feature_redraw_interval, reversible, ff_chunks, generalized_attention, kernel_fn, use_scalenorm, use_rezero, ff_glu, ff_dropout, attn_dropout, cross_attend, no_projection, auto_check_redraw, qkv_bias)
595
+ self.norm = nn.LayerNorm(dim)
596
+ self.to_out = nn.Linear(dim, num_tokens) if not tie_embed else None
597
+
598
+ def check_redraw_projections(self):
599
+ self.performer.check_redraw_projections()
600
+
601
+ def fix_projection_matrices_(self):
602
+ self.performer.fix_projection_matrices_()
603
+
604
+ def forward(self, x, return_encodings = False, output_attentions = False, **kwargs):
605
+ b, n, device = *x.shape, x.device
606
+ assert n <= self.max_seq_len, f'sequence length {n} must be less than the max sequence length {self.max_seq_len}'
607
+
608
+ #pdb.set_trace()
609
+ # token and positional embedding
610
+ x = self.token_emb(x)
611
+ if output_attentions:
612
+ x.requires_grad_() # used for attn_map output
613
+ x += self.pos_emb(x)
614
+ x = self.dropout(x)
615
+
616
+ # performer layers
617
+ layer_pos_emb = self.layer_pos_emb(x)
618
+
619
+ if output_attentions:
620
+ x, attn_weights = self.performer(x, pos_emb = layer_pos_emb, output_attentions = output_attentions, **kwargs)
621
+ # norm and to logits
622
+ x = self.norm(x)
623
+ if return_encodings:
624
+ return x, attn_weights
625
+
626
+ if exists(self.to_out):
627
+ return self.to_out(x), attn_weights
628
+
629
+ return (x @ self.token_emb.weight.t()), attn_weights
630
+ else:
631
+ x = self.performer(x, pos_emb = layer_pos_emb, output_attentions = output_attentions, **kwargs)
632
+
633
+ # norm and to logits
634
+ x = self.norm(x)
635
+ if return_encodings:
636
+ return x
637
+
638
+ if exists(self.to_out):
639
+ x = self.to_out(x)
640
+ return x
641
+
642
+ return x @ self.token_emb.weight.t()
performer_pytorch/reversible.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from operator import itemgetter
4
+ from torch.autograd.function import Function
5
+ from torch.utils.checkpoint import get_device_states, set_device_states
6
+
7
+ # for routing arguments into the functions of the reversible layer
8
+ def route_args(router, args, depth):
9
+ routed_args = [(dict(), dict()) for _ in range(depth)]
10
+ matched_keys = [key for key in args.keys() if key in router]
11
+
12
+ for key in matched_keys:
13
+ val = args[key]
14
+ for depth, ((f_args, g_args), routes) in enumerate(zip(routed_args, router[key])):
15
+ new_f_args, new_g_args = map(lambda route: ({key: val} if route else {}), routes)
16
+ routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args})
17
+ return routed_args
18
+
19
+ # following example for saving and setting rng here https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
20
+ class Deterministic(nn.Module):
21
+ def __init__(self, net):
22
+ super().__init__()
23
+ self.net = net
24
+ self.cpu_state = None
25
+ self.cuda_in_fwd = None
26
+ self.gpu_devices = None
27
+ self.gpu_states = None
28
+
29
+ def record_rng(self, *args):
30
+ self.cpu_state = torch.get_rng_state()
31
+ if torch.cuda._initialized:
32
+ self.cuda_in_fwd = True
33
+ self.gpu_devices, self.gpu_states = get_device_states(*args)
34
+
35
+ def forward(self, *args, record_rng = False, set_rng = False, **kwargs):
36
+ if record_rng:
37
+ self.record_rng(*args)
38
+
39
+ if not set_rng:
40
+ return self.net(*args, **kwargs)
41
+
42
+ rng_devices = []
43
+ if self.cuda_in_fwd:
44
+ rng_devices = self.gpu_devices
45
+
46
+ with torch.random.fork_rng(devices=rng_devices, enabled=True):
47
+ torch.set_rng_state(self.cpu_state)
48
+ if self.cuda_in_fwd:
49
+ set_device_states(self.gpu_devices, self.gpu_states)
50
+ return self.net(*args, **kwargs)
51
+
52
+ # heavily inspired by https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
53
+ # once multi-GPU is confirmed working, refactor and send PR back to source
54
+ class ReversibleBlock(nn.Module):
55
+ def __init__(self, f, g):
56
+ super().__init__()
57
+ self.f = Deterministic(f)
58
+ self.g = Deterministic(g)
59
+
60
+ def forward(self, x, f_args = {}, g_args = {}):
61
+ x1, x2 = torch.chunk(x, 2, dim=2)
62
+ y1, y2 = None, None
63
+
64
+ with torch.no_grad():
65
+ y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
66
+ y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
67
+
68
+ return torch.cat([y1, y2], dim=2)
69
+
70
+ def backward_pass(self, y, dy, f_args = {}, g_args = {}):
71
+ y1, y2 = torch.chunk(y, 2, dim=2)
72
+ del y
73
+
74
+ dy1, dy2 = torch.chunk(dy, 2, dim=2)
75
+ del dy
76
+
77
+ with torch.enable_grad():
78
+ y1.requires_grad = True
79
+ gy1 = self.g(y1, set_rng=True, **g_args)
80
+ torch.autograd.backward(gy1, dy2)
81
+
82
+ with torch.no_grad():
83
+ x2 = y2 - gy1
84
+ del y2, gy1
85
+
86
+ dx1 = dy1 + y1.grad
87
+ del dy1
88
+ y1.grad = None
89
+
90
+ with torch.enable_grad():
91
+ x2.requires_grad = True
92
+ fx2 = self.f(x2, set_rng=True, **f_args)
93
+ torch.autograd.backward(fx2, dx1, retain_graph=True)
94
+
95
+ with torch.no_grad():
96
+ x1 = y1 - fx2
97
+ del y1, fx2
98
+
99
+ dx2 = dy2 + x2.grad
100
+ del dy2
101
+ x2.grad = None
102
+
103
+ x = torch.cat([x1, x2.detach()], dim=2)
104
+ dx = torch.cat([dx1, dx2], dim=2)
105
+
106
+ return x, dx
107
+
108
+ class _ReversibleFunction(Function):
109
+ @staticmethod
110
+ def forward(ctx, x, blocks, args):
111
+ ctx.args = args
112
+ for block, kwarg in zip(blocks, args):
113
+ x = block(x, **kwarg)
114
+ ctx.y = x.detach()
115
+ ctx.blocks = blocks
116
+ return x
117
+
118
+ @staticmethod
119
+ def backward(ctx, dy):
120
+ y = ctx.y
121
+ args = ctx.args
122
+ for block, kwargs in zip(ctx.blocks[::-1], args[::-1]):
123
+ y, dy = block.backward_pass(y, dy, **kwargs)
124
+ return dy, None, None
125
+
126
+ class SequentialSequence(nn.Module):
127
+ def __init__(self, layers, args_route = {}):
128
+ super().__init__()
129
+ assert all(len(route) == len(layers) for route in args_route.values()), 'each argument route map must have the same depth as the number of sequential layers'
130
+ self.layers = layers
131
+ self.args_route = args_route
132
+
133
+ def forward(self, x, output_attentions = False, **kwargs):
134
+ args = route_args(self.args_route, kwargs, len(self.layers))
135
+ layers_and_args = list(zip(self.layers, args))
136
+
137
+ if output_attentions:
138
+ attn_weights = []
139
+ for (f, g), (f_args, g_args) in layers_and_args:
140
+ if output_attentions:
141
+ x = x + f(x, output_attentions = output_attentions, **f_args)[0]
142
+ attn_weights.append(f(x, output_attentions = output_attentions, **f_args)[1].unsqueeze(0))
143
+ else:
144
+ x = x + f(x, **f_args)
145
+ x = x + g(x, **g_args)
146
+ if output_attentions:
147
+ attn_weights = torch.transpose(torch.cat(attn_weights, dim=0), 0, 1) # the final dim is (batch, layer, head, len, len)
148
+ attn_weights = torch.mean(attn_weights, dim=1) # the dim is (batch, head, len, len)
149
+ return x, attn_weights
150
+ else:
151
+ return x
152
+
153
+ class ReversibleSequence(nn.Module):
154
+ def __init__(self, blocks, args_route = {}):
155
+ super().__init__()
156
+ self.args_route = args_route
157
+ self.blocks = nn.ModuleList([ReversibleBlock(f=f, g=g) for f, g in blocks])
158
+
159
+ def forward(self, x, **kwargs):
160
+ x = torch.cat([x, x], dim=-1)
161
+
162
+ blocks = self.blocks
163
+ args = route_args(self.args_route, kwargs, len(blocks))
164
+ args = list(map(lambda x: {'f_args': x[0], 'g_args': x[1]}, args))
165
+
166
+ out = _ReversibleFunction.apply(x, blocks, args)
167
+ return torch.stack(out.chunk(2, dim=-1)).sum(dim=0)