asigalov61 commited on
Commit
f43b552
1 Parent(s): fe053ba

Delete x_transformer.py

Browse files
Files changed (1) hide show
  1. x_transformer.py +0 -2001
x_transformer.py DELETED
@@ -1,2001 +0,0 @@
1
- #===================================================================================================================
2
-
3
- # X Trasformer Module
4
- # Partial x-transformers code With useful modifications
5
- #
6
- # Version 1.0
7
- #
8
- # Original source code courtesy of lucidrains
9
- # https://github.com/lucidrains/x-transformers
10
- #
11
- # Original source code retrieved on 05/10/2023
12
- #
13
- # Project Los Angeles
14
- # Tegridy Code 2023
15
-
16
- #===================================================================================================================
17
-
18
- # Critical dependencies
19
- #
20
- # !pip install torch
21
- # !pip install einops
22
-
23
- #===================================================================================================================
24
-
25
- from functools import partial
26
-
27
- import torch
28
- from torch import nn, einsum, Tensor
29
- import torch.nn.functional as F
30
-
31
- from collections import namedtuple
32
- from functools import wraps
33
- from packaging import version
34
- from dataclasses import dataclass
35
-
36
- from einops import rearrange
37
-
38
- import math
39
- from random import random
40
-
41
- from functools import partial
42
- from inspect import isfunction
43
-
44
- from dataclasses import dataclass
45
- from typing import List
46
-
47
- from einops import rearrange, repeat, reduce
48
- from einops.layers.torch import Rearrange
49
-
50
- from math import ceil
51
-
52
- from einops import rearrange, pack, unpack
53
-
54
- #===================================================================================================================
55
-
56
- # constants
57
-
58
- EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
59
-
60
- @dataclass
61
- class Intermediates:
62
- qk_similarities: Tensor = None
63
- pre_softmax_attn: Tensor = None
64
- post_softmax_attn: Tensor = None
65
-
66
- # helpers
67
-
68
- def exists(val):
69
- return val is not None
70
-
71
- def default(val, d):
72
- return val if exists(val) else d
73
-
74
- def once(fn):
75
- called = False
76
- @wraps(fn)
77
- def inner(x):
78
- nonlocal called
79
- if called:
80
- return
81
- called = True
82
- return fn(x)
83
- return inner
84
-
85
- print_once = once(print)
86
-
87
- # main class
88
-
89
- class Attend(nn.Module):
90
- def __init__(
91
- self,
92
- *,
93
- dropout = 0.,
94
- causal = False,
95
- heads = None,
96
- talking_heads = False,
97
- scale = None,
98
- qk_norm = False,
99
- flash = False,
100
- ):
101
- super().__init__()
102
- self.scale = scale
103
- self.qk_norm = qk_norm
104
- self.causal = causal
105
- self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax
106
-
107
- self.dropout = dropout
108
- self.attn_dropout = nn.Dropout(dropout)
109
-
110
- # talking heads
111
-
112
- assert not (flash and talking_heads), 'talking heads not compatible with flash attention'
113
-
114
- self.talking_heads = talking_heads
115
- if talking_heads:
116
- self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
117
- self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
118
-
119
- # flash attention
120
-
121
- self.flash = flash
122
- assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
123
-
124
- # determine efficient attention configs for cuda and cpu
125
-
126
- self.cpu_config = EfficientAttentionConfig(True, True, True)
127
- self.cuda_config = None
128
-
129
- if not torch.cuda.is_available() or not flash:
130
- return
131
-
132
- device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
133
-
134
- if device_properties.major == 8 and device_properties.minor == 0:
135
- print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
136
- self.cuda_config = EfficientAttentionConfig(True, False, False)
137
- else:
138
- print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
139
- self.cuda_config = EfficientAttentionConfig(False, True, True)
140
-
141
- def flash_attn(
142
- self,
143
- q, k, v,
144
- mask = None,
145
- attn_bias = None
146
- ):
147
- batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
148
-
149
- # Recommended for multi-query single-key-value attention by Tri Dao
150
- # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
151
-
152
- if k.ndim == 3:
153
- k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
154
-
155
- if v.ndim == 3:
156
- v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
157
-
158
- # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention
159
-
160
- if self.qk_norm:
161
- default_scale = q.shape[-1] ** -0.5
162
- q = q * (default_scale / self.scale)
163
-
164
- # Check if mask exists and expand to compatible shape
165
- # The mask is B L, so it would have to be expanded to B H N L
166
-
167
- causal = self.causal
168
-
169
- if exists(mask):
170
- assert mask.ndim == 4
171
- mask = mask.expand(batch, heads, q_len, k_len)
172
-
173
- # manually handle causal mask, if another mask was given
174
-
175
- if causal:
176
- causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
177
- mask = mask | causal_mask
178
- causal = False
179
-
180
- # handle alibi positional bias
181
- # convert from bool to float
182
-
183
- if exists(attn_bias):
184
- attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, -1, -1, -1)
185
-
186
- # if mask given, the mask would already contain the causal mask from above logic
187
- # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
188
-
189
- mask_value = -torch.finfo(q.dtype).max
190
-
191
- if exists(mask):
192
- attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
193
- elif causal:
194
- causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
195
- attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
196
- causal = False
197
-
198
- # scaled_dot_product_attention handles attn_mask either as bool or additive bias
199
- # make it an additive bias here
200
-
201
- mask = attn_bias
202
-
203
- # Check if there is a compatible device for flash attention
204
-
205
- config = self.cuda_config if is_cuda else self.cpu_config
206
-
207
- # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
208
-
209
- with torch.backends.cuda.sdp_kernel(**config._asdict()):
210
- out = F.scaled_dot_product_attention(
211
- q, k, v,
212
- attn_mask = mask,
213
- dropout_p = self.dropout if self.training else 0.,
214
- is_causal = causal
215
- )
216
-
217
- return out, Intermediates()
218
-
219
- def forward(
220
- self,
221
- q, k, v,
222
- mask = None,
223
- attn_bias = None,
224
- prev_attn = None
225
- ):
226
- """
227
- einstein notation
228
- b - batch
229
- h - heads
230
- n, i, j - sequence length (base sequence length, source, target)
231
- d - feature dimension
232
- """
233
-
234
- n, device = q.shape[-2], q.device
235
-
236
- scale = default(self.scale, q.shape[-1] ** -0.5)
237
-
238
- if self.flash:
239
- assert not exists(prev_attn), 'residual attention not compatible with flash attention'
240
- return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
241
-
242
- kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
243
-
244
- dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
245
-
246
- if exists(prev_attn):
247
- dots = dots + prev_attn
248
-
249
- qk_similarities = dots.clone()
250
-
251
- if self.talking_heads:
252
- dots = self.pre_softmax_talking_heads(dots)
253
-
254
- if exists(attn_bias):
255
- dots = dots + attn_bias
256
-
257
- dtype = dots.dtype
258
- pre_softmax_attn = dots.clone()
259
-
260
- mask_value = -torch.finfo(dots.dtype).max
261
-
262
- if exists(mask):
263
- dots = dots.masked_fill(mask, mask_value)
264
-
265
- if self.causal:
266
- i, j = dots.shape[-2:]
267
- causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
268
- dots = dots.masked_fill(causal_mask, mask_value)
269
-
270
- attn = self.attn_fn(dots, dim = -1)
271
- attn = attn.type(dtype)
272
-
273
- post_softmax_attn = attn.clone()
274
-
275
- attn = self.attn_dropout(attn)
276
-
277
- if self.talking_heads:
278
- attn = self.post_softmax_talking_heads(attn)
279
-
280
- out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
281
-
282
- intermediates = Intermediates(
283
- qk_similarities = qk_similarities,
284
- pre_softmax_attn = pre_softmax_attn,
285
- post_softmax_attn = post_softmax_attn
286
- )
287
-
288
- return out, intermediates
289
-
290
- #===================================================================================================================
291
-
292
- # constants
293
-
294
- DEFAULT_DIM_HEAD = 64
295
-
296
- @dataclass
297
- class LayerIntermediates:
298
- hiddens: List[Tensor] = None,
299
- attn_intermediates: List[Intermediates] = None
300
-
301
- # helpers
302
-
303
- def exists(val):
304
- return val is not None
305
-
306
- def default(val, d):
307
- if exists(val):
308
- return val
309
- return d() if isfunction(d) else d
310
-
311
- def cast_tuple(val, depth):
312
- return val if isinstance(val, tuple) else (val,) * depth
313
-
314
- def maybe(fn):
315
- @wraps(fn)
316
- def inner(x, *args, **kwargs):
317
- if not exists(x):
318
- return x
319
- return fn(x, *args, **kwargs)
320
- return inner
321
-
322
- class always():
323
- def __init__(self, val):
324
- self.val = val
325
- def __call__(self, *args, **kwargs):
326
- return self.val
327
-
328
- class not_equals():
329
- def __init__(self, val):
330
- self.val = val
331
- def __call__(self, x, *args, **kwargs):
332
- return x != self.val
333
-
334
- class equals():
335
- def __init__(self, val):
336
- self.val = val
337
- def __call__(self, x, *args, **kwargs):
338
- return x == self.val
339
-
340
- # tensor helpers
341
-
342
- def max_neg_value(tensor):
343
- return -torch.finfo(tensor.dtype).max
344
-
345
- def l2norm(t, groups = 1):
346
- t = rearrange(t, '... (g d) -> ... g d', g = groups)
347
- t = F.normalize(t, p = 2, dim = -1)
348
- return rearrange(t, '... g d -> ... (g d)')
349
-
350
- def pad_at_dim(t, pad, dim = -1, value = 0.):
351
- dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
352
- zeros = ((0, 0) * dims_from_right)
353
- return F.pad(t, (*zeros, *pad), value = value)
354
-
355
- def or_reduce(masks):
356
- head, *body = masks
357
- for rest in body:
358
- head = head | rest
359
- return head
360
-
361
- # init helpers
362
-
363
- def init_zero_(layer):
364
- nn.init.constant_(layer.weight, 0.)
365
- if exists(layer.bias):
366
- nn.init.constant_(layer.bias, 0.)
367
-
368
- # keyword argument helpers
369
-
370
- def pick_and_pop(keys, d):
371
- values = list(map(lambda key: d.pop(key), keys))
372
- return dict(zip(keys, values))
373
-
374
- def group_dict_by_key(cond, d):
375
- return_val = [dict(),dict()]
376
- for key in d.keys():
377
- match = bool(cond(key))
378
- ind = int(not match)
379
- return_val[ind][key] = d[key]
380
- return (*return_val,)
381
-
382
- def string_begins_with(prefix, str):
383
- return str.startswith(prefix)
384
-
385
- def group_by_key_prefix(prefix, d):
386
- return group_dict_by_key(partial(string_begins_with, prefix), d)
387
-
388
- def groupby_prefix_and_trim(prefix, d):
389
- kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
390
- kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
391
- return kwargs_without_prefix, kwargs
392
-
393
- # initializations
394
-
395
- def deepnorm_init(
396
- transformer,
397
- beta,
398
- module_name_match_list = ['.ff.', '.to_v', '.to_out']
399
- ):
400
- for name, module in transformer.named_modules():
401
- if type(module) != nn.Linear:
402
- continue
403
-
404
- needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
405
- gain = beta if needs_beta_gain else 1
406
- nn.init.xavier_normal_(module.weight.data, gain = gain)
407
-
408
- if exists(module.bias):
409
- nn.init.constant_(module.bias.data, 0)
410
-
411
- # structured dropout, more effective than traditional attention dropouts
412
-
413
- def dropout_seq(seq, mask, dropout):
414
- b, n, *_, device = *seq.shape, seq.device
415
- logits = torch.randn(b, n, device = device)
416
-
417
- if exists(mask):
418
- mask_value = max_neg_value(logits)
419
- logits = logits.masked_fill(~mask, mask_value)
420
-
421
- keep_prob = 1. - dropout
422
- num_keep = max(1, int(keep_prob * n))
423
- keep_indices = logits.topk(num_keep, dim = 1).indices
424
-
425
- batch_indices = torch.arange(b, device = device)
426
- batch_indices = rearrange(batch_indices, 'b -> b 1')
427
-
428
- seq = seq[batch_indices, keep_indices]
429
-
430
- if exists(mask):
431
- seq_counts = mask.sum(dim = -1)
432
- seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
433
- keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
434
-
435
- mask = mask[batch_indices, keep_indices] & keep_mask
436
-
437
- return seq, mask
438
-
439
- # activations
440
-
441
- class ReluSquared(nn.Module):
442
- def forward(self, x):
443
- return F.relu(x) ** 2
444
-
445
- # embedding
446
-
447
- class TokenEmbedding(nn.Module):
448
- def __init__(self, dim, num_tokens, l2norm_embed = False):
449
- super().__init__()
450
- self.l2norm_embed = l2norm_embed
451
- self.emb = nn.Embedding(num_tokens, dim)
452
-
453
- def forward(self, x):
454
- token_emb = self.emb(x)
455
- return l2norm(token_emb) if self.l2norm_embed else token_emb
456
-
457
- # positional embeddings
458
-
459
- class AbsolutePositionalEmbedding(nn.Module):
460
- def __init__(self, dim, max_seq_len, l2norm_embed = False):
461
- super().__init__()
462
- self.scale = dim ** -0.5 if not l2norm_embed else 1.
463
- self.max_seq_len = max_seq_len
464
- self.l2norm_embed = l2norm_embed
465
- self.emb = nn.Embedding(max_seq_len, dim)
466
-
467
- def forward(self, x, pos = None):
468
- seq_len, device = x.shape[1], x.device
469
- assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
470
-
471
- if not exists(pos):
472
- pos = torch.arange(seq_len, device = device)
473
-
474
- pos_emb = self.emb(pos)
475
- pos_emb = pos_emb * self.scale
476
- return l2norm(pos_emb) if self.l2norm_embed else pos_emb
477
-
478
- class ScaledSinusoidalEmbedding(nn.Module):
479
- def __init__(self, dim, theta = 10000):
480
- super().__init__()
481
- assert (dim % 2) == 0
482
- self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
483
-
484
- half_dim = dim // 2
485
- freq_seq = torch.arange(half_dim).float() / half_dim
486
- inv_freq = theta ** -freq_seq
487
- self.register_buffer('inv_freq', inv_freq, persistent = False)
488
-
489
- def forward(self, x, pos = None):
490
- seq_len, device = x.shape[1], x.device
491
-
492
- if not exists(pos):
493
- pos = torch.arange(seq_len, device = device)
494
-
495
- emb = einsum('i, j -> i j', pos, self.inv_freq)
496
- emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
497
- return emb * self.scale
498
-
499
- class RelativePositionBias(nn.Module):
500
- def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
501
- super().__init__()
502
- self.scale = scale
503
- self.causal = causal
504
- self.num_buckets = num_buckets
505
- self.max_distance = max_distance
506
- self.relative_attention_bias = nn.Embedding(num_buckets, heads)
507
-
508
- @staticmethod
509
- def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128):
510
- ret = 0
511
- n = -relative_position
512
- if not causal:
513
- num_buckets //= 2
514
- ret += (n < 0).long() * num_buckets
515
- n = torch.abs(n)
516
- else:
517
- n = torch.max(n, torch.zeros_like(n))
518
-
519
- max_exact = num_buckets // 2
520
- is_small = n < max_exact
521
-
522
- val_if_large = max_exact + (
523
- torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
524
- ).long()
525
- val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
526
-
527
- ret += torch.where(is_small, n, val_if_large)
528
- return ret
529
-
530
- @property
531
- def device(self):
532
- return next(self.parameters()).device
533
-
534
- def forward(self, i, j):
535
- device = self.device
536
- q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
537
- k_pos = torch.arange(j, dtype = torch.long, device = device)
538
- rel_pos = k_pos[None, :] - q_pos[:, None]
539
- rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
540
- values = self.relative_attention_bias(rp_bucket)
541
- bias = rearrange(values, 'i j h -> h i j')
542
- return bias * self.scale
543
-
544
- class DynamicPositionBias(nn.Module):
545
- def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
546
- super().__init__()
547
- assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
548
- self.log_distance = log_distance
549
-
550
- self.mlp = nn.ModuleList([])
551
-
552
- self.mlp.append(nn.Sequential(
553
- nn.Linear(1, dim),
554
- nn.LayerNorm(dim) if norm else nn.Identity(),
555
- nn.SiLU()
556
- ))
557
-
558
- for _ in range(depth - 1):
559
- self.mlp.append(nn.Sequential(
560
- nn.Linear(dim, dim),
561
- nn.LayerNorm(dim) if norm else nn.Identity(),
562
- nn.SiLU()
563
- ))
564
-
565
- self.mlp.append(nn.Linear(dim, heads))
566
-
567
- @property
568
- def device(self):
569
- return next(self.parameters()).device
570
-
571
- def forward(self, i, j):
572
- assert i == j
573
- n, device = j, self.device
574
-
575
- # get the (n x n) matrix of distances
576
- seq_arange = torch.arange(n, device = device)
577
- context_arange = torch.arange(n, device = device)
578
- indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
579
- indices += (n - 1)
580
-
581
- # input to continuous positions MLP
582
- pos = torch.arange(-n + 1, n, device = device).float()
583
- pos = rearrange(pos, '... -> ... 1')
584
-
585
- if self.log_distance:
586
- pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)
587
-
588
- for layer in self.mlp:
589
- pos = layer(pos)
590
-
591
- # get position biases
592
- bias = pos[indices]
593
- bias = rearrange(bias, 'i j h -> h i j')
594
- return bias
595
-
596
- class AlibiPositionalBias(nn.Module):
597
- def __init__(self, heads, total_heads, **kwargs):
598
- super().__init__()
599
- self.heads = heads
600
- self.total_heads = total_heads
601
-
602
- slopes = Tensor(self._get_slopes(heads))
603
- slopes = rearrange(slopes, 'h -> h 1 1')
604
- self.register_buffer('slopes', slopes, persistent = False)
605
- self.register_buffer('bias', None, persistent = False)
606
-
607
- def get_bias(self, i, j, device):
608
- i_arange = torch.arange(j - i, j, device = device)
609
- j_arange = torch.arange(j, device = device)
610
- bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
611
- return bias
612
-
613
- @staticmethod
614
- def _get_slopes(heads):
615
- def get_slopes_power_of_2(n):
616
- start = (2**(-2**-(math.log2(n)-3)))
617
- ratio = start
618
- return [start*ratio**i for i in range(n)]
619
-
620
- if math.log2(heads).is_integer():
621
- return get_slopes_power_of_2(heads)
622
-
623
- closest_power_of_2 = 2 ** math.floor(math.log2(heads))
624
- return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
625
-
626
- @property
627
- def device(self):
628
- return next(self.buffers()).device
629
-
630
- def forward(self, i, j):
631
- h, device = self.total_heads, self.device
632
-
633
- if exists(self.bias) and self.bias.shape[-1] >= j:
634
- return self.bias[..., :i, :j]
635
-
636
- bias = self.get_bias(i, j, device)
637
- bias = bias * self.slopes
638
-
639
- num_heads_unalibied = h - bias.shape[0]
640
- bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0)
641
- self.register_buffer('bias', bias, persistent = False)
642
-
643
- return self.bias
644
-
645
- class LearnedAlibiPositionalBias(AlibiPositionalBias):
646
- def __init__(self, heads, total_heads):
647
- super().__init__(heads, total_heads)
648
- log_slopes = torch.log(self.slopes)
649
- self.learned_logslopes = nn.Parameter(log_slopes)
650
-
651
- def forward(self, i, j):
652
- h, i, j, device = self.heads, self.device
653
-
654
- def get_slopes(param):
655
- return pad_at_dim(param.exp(), (0, h - param.shape[0]), dim = -2)
656
-
657
- if exists(self.bias) and self.bias.shape[-1] >= j:
658
- bias = self.bias[..., :i, :j]
659
- else:
660
- bias = self.get_bias(i, j, device)
661
- self.register_buffer('bias', bias, persistent = False)
662
-
663
- slopes = get_slopes(self.learned_logslopes)
664
- bias = bias * slopes
665
-
666
- return bias
667
-
668
- class RotaryEmbedding(nn.Module):
669
- def __init__(
670
- self,
671
- dim,
672
- use_xpos = False,
673
- scale_base = 512
674
- ):
675
- super().__init__()
676
- inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
677
- self.register_buffer('inv_freq', inv_freq)
678
-
679
- if not use_xpos:
680
- self.register_buffer('scale', None)
681
- return
682
-
683
- scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
684
-
685
- self.scale_base = scale_base
686
- self.register_buffer('scale', scale)
687
-
688
- def forward(self, seq_len, device):
689
- t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
690
- freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
691
- freqs = torch.cat((freqs, freqs), dim = -1)
692
-
693
- if not exists(self.scale):
694
- return freqs, 1.
695
-
696
- power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
697
- scale = self.scale ** rearrange(power, 'n -> n 1')
698
- scale = torch.cat((scale, scale), dim = -1)
699
-
700
- return freqs, scale
701
-
702
-
703
- def rotate_half(x):
704
- x = rearrange(x, '... (j d) -> ... j d', j = 2)
705
- x1, x2 = x.unbind(dim = -2)
706
- return torch.cat((-x2, x1), dim = -1)
707
-
708
- def apply_rotary_pos_emb(t, freqs, scale = 1):
709
- seq_len = t.shape[-2]
710
- freqs = freqs[-seq_len:, :]
711
- return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
712
-
713
- # norms
714
-
715
- class Scale(nn.Module):
716
- def __init__(self, value, fn):
717
- super().__init__()
718
- self.value = value
719
- self.fn = fn
720
-
721
- def forward(self, x, **kwargs):
722
- out = self.fn(x, **kwargs)
723
- scale_fn = lambda t: t * self.value
724
-
725
- if not isinstance(out, tuple):
726
- return scale_fn(out)
727
-
728
- return (scale_fn(out[0]), *out[1:])
729
-
730
- class ScaleNorm(nn.Module):
731
- def __init__(self, dim, eps = 1e-5):
732
- super().__init__()
733
- self.eps = eps
734
- self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))
735
-
736
- def forward(self, x):
737
- norm = torch.norm(x, dim = -1, keepdim = True)
738
- return x / norm.clamp(min = self.eps) * self.g
739
-
740
- class RMSNorm(nn.Module):
741
- def __init__(self, dim, eps = 1e-8):
742
- super().__init__()
743
- self.scale = dim ** -0.5
744
- self.eps = eps
745
- self.g = nn.Parameter(torch.ones(dim))
746
-
747
- def forward(self, x):
748
- norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
749
- return x / norm.clamp(min = self.eps) * self.g
750
-
751
- # residual and residual gates
752
-
753
- class Residual(nn.Module):
754
- def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.):
755
- super().__init__()
756
- self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
757
- self.scale_residual_constant = scale_residual_constant
758
-
759
- def forward(self, x, residual):
760
- if exists(self.residual_scale):
761
- residual = residual * self.residual_scale
762
-
763
- if self.scale_residual_constant != 1:
764
- residual = residual * self.scale_residual_constant
765
-
766
- return x + residual
767
-
768
- class GRUGating(nn.Module):
769
- def __init__(self, dim, scale_residual = False, **kwargs):
770
- super().__init__()
771
- self.gru = nn.GRUCell(dim, dim)
772
- self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None
773
-
774
- def forward(self, x, residual):
775
- if exists(self.residual_scale):
776
- residual = residual * self.residual_scale
777
-
778
- gated_output = self.gru(
779
- rearrange(x, 'b n d -> (b n) d'),
780
- rearrange(residual, 'b n d -> (b n) d')
781
- )
782
-
783
- return gated_output.reshape_as(x)
784
-
785
- # token shifting
786
-
787
- def shift(t, amount, mask = None):
788
- if amount == 0:
789
- return t
790
- else:
791
- amount = min(amount, t.shape[1])
792
-
793
- if exists(mask):
794
- t = t.masked_fill(~mask[..., None], 0.)
795
-
796
- return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.)
797
-
798
- class ShiftTokens(nn.Module):
799
- def __init__(self, shifts, fn):
800
- super().__init__()
801
- self.fn = fn
802
- self.shifts = tuple(shifts)
803
-
804
- def forward(self, x, **kwargs):
805
- mask = kwargs.get('mask', None)
806
- shifts = self.shifts
807
- segments = len(shifts)
808
- feats_per_shift = x.shape[-1] // segments
809
- splitted = x.split(feats_per_shift, dim = -1)
810
- segments_to_shift, rest = splitted[:segments], splitted[segments:]
811
- segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts)))
812
- x = torch.cat((*segments_to_shift, *rest), dim = -1)
813
- return self.fn(x, **kwargs)
814
-
815
- # feedforward
816
-
817
- class GLU(nn.Module):
818
- def __init__(self, dim_in, dim_out, activation):
819
- super().__init__()
820
- self.act = activation
821
- self.proj = nn.Linear(dim_in, dim_out * 2)
822
-
823
- def forward(self, x):
824
- x, gate = self.proj(x).chunk(2, dim = -1)
825
- return x * self.act(gate)
826
-
827
- class FeedForward(nn.Module):
828
- def __init__(
829
- self,
830
- dim,
831
- dim_out = None,
832
- mult = 4,
833
- glu = False,
834
- swish = False,
835
- relu_squared = False,
836
- post_act_ln = False,
837
- dropout = 0.,
838
- no_bias = False,
839
- zero_init_output = False
840
- ):
841
- super().__init__()
842
- inner_dim = int(dim * mult)
843
- dim_out = default(dim_out, dim)
844
-
845
- if relu_squared:
846
- activation = ReluSquared()
847
- elif swish:
848
- activation = nn.SiLU()
849
- else:
850
- activation = nn.GELU()
851
-
852
- project_in = nn.Sequential(
853
- nn.Linear(dim, inner_dim, bias = not no_bias),
854
- activation
855
- ) if not glu else GLU(dim, inner_dim, activation)
856
-
857
- self.ff = nn.Sequential(
858
- project_in,
859
- nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
860
- nn.Dropout(dropout),
861
- nn.Linear(inner_dim, dim_out, bias = not no_bias)
862
- )
863
-
864
- # init last linear layer to 0
865
- if zero_init_output:
866
- init_zero_(self.ff[-1])
867
-
868
- def forward(self, x):
869
- return self.ff(x)
870
-
871
- # attention. it is all we need
872
-
873
- class Attention(nn.Module):
874
- def __init__(
875
- self,
876
- dim,
877
- dim_head = DEFAULT_DIM_HEAD,
878
- heads = 8,
879
- causal = False,
880
- flash = False,
881
- talking_heads = False,
882
- head_scale = False,
883
- sparse_topk = None,
884
- num_mem_kv = 0,
885
- dropout = 0.,
886
- on_attn = False,
887
- gate_values = False,
888
- zero_init_output = False,
889
- max_attend_past = None,
890
- qk_norm = False,
891
- qk_norm_groups = 1,
892
- qk_norm_scale = 10,
893
- qk_norm_dim_scale = False,
894
- one_kv_head = False,
895
- shared_kv = False,
896
- value_dim_head = None,
897
- tensor_product = False # https://arxiv.org/abs/2208.06061
898
- ):
899
- super().__init__()
900
- self.scale = dim_head ** -0.5
901
-
902
- self.heads = heads
903
- self.causal = causal
904
- self.max_attend_past = max_attend_past
905
-
906
- value_dim_head = default(value_dim_head, dim_head)
907
- q_dim = k_dim = dim_head * heads
908
- v_dim = out_dim = value_dim_head * heads
909
-
910
- self.one_kv_head = one_kv_head
911
- if one_kv_head:
912
- k_dim = dim_head
913
- v_dim = value_dim_head
914
- out_dim = v_dim * heads
915
-
916
- self.to_q = nn.Linear(dim, q_dim, bias = False)
917
- self.to_k = nn.Linear(dim, k_dim, bias = False)
918
-
919
- # shared key / values, for further memory savings during inference
920
- assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
921
- self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None
922
-
923
- # relations projection from tp-attention
924
- self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None
925
-
926
- # add GLU gating for aggregated values, from alphafold2
927
- self.to_v_gate = None
928
- if gate_values:
929
- self.to_v_gate = nn.Linear(dim, out_dim)
930
- nn.init.constant_(self.to_v_gate.weight, 0)
931
- nn.init.constant_(self.to_v_gate.bias, 1)
932
-
933
- # cosine sim attention
934
- self.qk_norm = qk_norm
935
- self.qk_norm_groups = qk_norm_groups
936
- self.qk_norm_scale = qk_norm_scale
937
-
938
- # whether to use the rmsnorm (equivalent to cosine sim attention when scale is equal to 1) - https://arxiv.org/abs/2302.05442
939
- self.qk_norm_dim_scale = qk_norm_dim_scale
940
-
941
- self.qk_norm_q_scale = self.qk_norm_k_scale = 1
942
- if qk_norm and qk_norm_dim_scale:
943
- self.qk_norm_q_scale = nn.Parameter(torch.ones(dim_head))
944
- self.qk_norm_k_scale = nn.Parameter(torch.ones(dim_head))
945
-
946
- assert (not qk_norm) or (dim_head % qk_norm_groups) == 0, 'dimension per attention head must be divisible by the qk norm groups'
947
- assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
948
-
949
- # attend class - includes core attention algorithm + talking heads
950
-
951
- self.attend = Attend(
952
- heads = heads,
953
- causal = causal,
954
- talking_heads = talking_heads,
955
- dropout = dropout,
956
- qk_norm = qk_norm,
957
- scale = qk_norm_scale if qk_norm else self.scale,
958
- flash = flash
959
- )
960
-
961
- # head scaling
962
- self.head_scale = head_scale
963
- if head_scale:
964
- self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))
965
-
966
- # explicit topk sparse attention
967
- self.sparse_topk = sparse_topk
968
-
969
- # add memory key / values
970
- self.num_mem_kv = num_mem_kv
971
- if num_mem_kv > 0:
972
- self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
973
- self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
974
-
975
- # attention on attention
976
- self.attn_on_attn = on_attn
977
- self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False)
978
-
979
- # init output projection 0
980
- if zero_init_output:
981
- init_zero_(self.to_out)
982
-
983
- def forward(
984
- self,
985
- x,
986
- context = None,
987
- mask = None,
988
- context_mask = None,
989
- attn_mask = None,
990
- rel_pos = None,
991
- rotary_pos_emb = None,
992
- prev_attn = None,
993
- mem = None
994
- ):
995
- b, n, _, h, head_scale, device, has_context = *x.shape, self.heads, self.head_scale, x.device, exists(context)
996
- kv_input = default(context, x)
997
-
998
- q_input = x
999
- k_input = kv_input
1000
- v_input = kv_input
1001
- r_input = x
1002
-
1003
- if exists(mem):
1004
- k_input = torch.cat((mem, k_input), dim = -2)
1005
- v_input = torch.cat((mem, v_input), dim = -2)
1006
-
1007
- q = self.to_q(q_input)
1008
- k = self.to_k(k_input)
1009
- v = self.to_v(v_input) if exists(self.to_v) else k
1010
- r = self.to_r(r_input) if exists(self.to_r) else None
1011
-
1012
- q = rearrange(q, 'b n (h d) -> b h n d', h = h)
1013
-
1014
- if not self.one_kv_head:
1015
- k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = h), (k, v, r))
1016
-
1017
- if self.qk_norm:
1018
- qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
1019
- q, k = map(qk_l2norm, (q, k))
1020
- scale = self.qk_norm_scale
1021
-
1022
- q = q * self.qk_norm_q_scale
1023
- k = k * self.qk_norm_k_scale
1024
-
1025
- if exists(rotary_pos_emb) and not has_context:
1026
- freqs, xpos_scale = rotary_pos_emb
1027
- l = freqs.shape[-1]
1028
-
1029
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)
1030
- (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v))
1031
-
1032
- ql, kl, vl = map(lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]), ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale)))
1033
- q, k, v = map(lambda t: torch.cat(t, dim = -1), ((ql, qr), (kl, kr), (vl, vr)))
1034
-
1035
- input_mask = default(context_mask, mask)
1036
-
1037
- if self.num_mem_kv > 0:
1038
- mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
1039
-
1040
- if self.qk_norm:
1041
- mem_k = l2norm(mem_k)
1042
- mem_k = mem_k * self.qk_norm_k_scale
1043
-
1044
- k = torch.cat((mem_k, k), dim = -2)
1045
- v = torch.cat((mem_v, v), dim = -2)
1046
-
1047
- if exists(input_mask):
1048
- input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
1049
-
1050
-
1051
- i, j = map(lambda t: t.shape[-2], (q, k))
1052
-
1053
- # determine masking
1054
-
1055
- mask_value = max_neg_value(q)
1056
- masks = []
1057
- final_attn_mask = None
1058
-
1059
- if exists(input_mask):
1060
- input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
1061
- masks.append(~input_mask)
1062
-
1063
- if exists(attn_mask):
1064
- assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4'
1065
- if attn_mask.ndim == 2:
1066
- attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j')
1067
- elif attn_mask.ndim == 3:
1068
- attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j')
1069
- masks.append(~attn_mask)
1070
-
1071
- if exists(self.max_attend_past):
1072
- range_q = torch.arange(j - i, j, device = device)
1073
- range_k = torch.arange(j, device = device)
1074
- dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
1075
- max_attend_past_mask = dist > self.max_attend_past
1076
- masks.append(max_attend_past_mask)
1077
-
1078
- if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
1079
- top, _ = dots.topk(self.sparse_topk, dim = -1)
1080
- vk = rearrange(top[..., -1], '... -> ... 1')
1081
- sparse_topk_mask = dots < vk
1082
- masks.append(sparse_topk_mask)
1083
-
1084
- if len(masks) > 0:
1085
- final_attn_mask = or_reduce(masks)
1086
-
1087
- # prepare relative positional bias, if needed
1088
-
1089
- attn_bias = None
1090
- if exists(rel_pos):
1091
- attn_bias = rel_pos(i, j)
1092
-
1093
- # attention is all we need
1094
-
1095
- out, intermediates = self.attend(
1096
- q, k, v,
1097
- mask = final_attn_mask,
1098
- attn_bias = attn_bias,
1099
- prev_attn = prev_attn
1100
- )
1101
-
1102
- # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
1103
-
1104
- if exists(r):
1105
- out = out * r + out
1106
-
1107
- # normformer scaling of heads
1108
-
1109
- if head_scale:
1110
- out = out * self.head_scale_params
1111
-
1112
- # merge heads
1113
-
1114
- out = rearrange(out, 'b h n d -> b n (h d)')
1115
-
1116
- # alphafold2 styled gating of the values
1117
-
1118
- if exists(self.to_v_gate):
1119
- gates = self.to_v_gate(x)
1120
- out = out * gates.sigmoid()
1121
-
1122
- # combine the heads
1123
-
1124
- out = self.to_out(out)
1125
-
1126
- if exists(mask):
1127
- mask = rearrange(mask, 'b n -> b n 1')
1128
- out = out.masked_fill(~mask, 0.)
1129
-
1130
- return out, intermediates
1131
-
1132
- class AttentionLayers(nn.Module):
1133
- def __init__(
1134
- self,
1135
- dim,
1136
- depth,
1137
- heads = 8,
1138
- causal = False,
1139
- cross_attend = False,
1140
- only_cross = False,
1141
- use_scalenorm = False,
1142
- use_rmsnorm = False,
1143
- alibi_pos_bias = False,
1144
- alibi_num_heads = None,
1145
- alibi_learned = False,
1146
- rel_pos_bias = False,
1147
- rel_pos_num_buckets = 32,
1148
- rel_pos_max_distance = 128,
1149
- dynamic_pos_bias = False,
1150
- dynamic_pos_bias_log_distance = False,
1151
- dynamic_pos_bias_mlp_depth = 2,
1152
- dynamic_pos_bias_norm = False,
1153
- rotary_pos_emb = False,
1154
- rotary_emb_dim = None,
1155
- rotary_xpos = False,
1156
- rotary_xpos_scale_base = 512,
1157
- custom_layers = None,
1158
- sandwich_coef = None,
1159
- par_ratio = None,
1160
- residual_attn = False,
1161
- cross_residual_attn = False,
1162
- macaron = False,
1163
- pre_norm = True,
1164
- gate_residual = False,
1165
- scale_residual = False,
1166
- scale_residual_constant = 1.,
1167
- deepnorm = False,
1168
- shift_tokens = 0,
1169
- sandwich_norm = False,
1170
- resi_dual = False,
1171
- zero_init_branch_output = False,
1172
- layer_dropout = 0.,
1173
- cross_attn_tokens_dropout = 0.,
1174
- **kwargs
1175
- ):
1176
- super().__init__()
1177
- rotary_pos_emb = rotary_pos_emb or rotary_xpos
1178
-
1179
- ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
1180
- attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
1181
-
1182
- dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
1183
-
1184
- self.dim = dim
1185
- self.depth = depth
1186
- self.layers = nn.ModuleList([])
1187
-
1188
- self.has_pos_emb = rel_pos_bias or rotary_pos_emb
1189
-
1190
- rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
1191
-
1192
- assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
1193
- self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base) if rotary_pos_emb else None
1194
-
1195
- assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
1196
- assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
1197
-
1198
- # relative positional bias
1199
-
1200
- flash_attn = attn_kwargs.get('flash', False)
1201
- assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
1202
-
1203
- self.rel_pos = None
1204
- if rel_pos_bias:
1205
- assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
1206
- self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
1207
- elif dynamic_pos_bias:
1208
- assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
1209
- self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
1210
- elif alibi_pos_bias:
1211
- alibi_num_heads = default(alibi_num_heads, heads)
1212
- assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
1213
- alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned else AlibiPositionalBias
1214
- self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, total_heads = heads)
1215
-
1216
- # determine deepnorm and residual scale
1217
-
1218
- if deepnorm:
1219
- assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings'
1220
- pre_norm = sandwich_norm = resi_dual = False
1221
- scale_residual = True
1222
- scale_residual_constant = (2 * depth) ** 0.25
1223
-
1224
- assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
1225
- assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
1226
- assert not (not pre_norm and resi_dual), 'resiDualcannot be used when not using prenorm'
1227
- self.pre_norm = pre_norm
1228
- self.sandwich_norm = sandwich_norm
1229
- self.resi_dual = resi_dual
1230
-
1231
- self.residual_attn = residual_attn
1232
- self.cross_residual_attn = cross_residual_attn
1233
- self.cross_attend = cross_attend
1234
-
1235
- norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
1236
- norm_class = RMSNorm if use_rmsnorm else norm_class
1237
- norm_fn = partial(norm_class, dim)
1238
-
1239
- if cross_attend and not only_cross:
1240
- default_block = ('a', 'c', 'f')
1241
- elif cross_attend and only_cross:
1242
- default_block = ('c', 'f')
1243
- else:
1244
- default_block = ('a', 'f')
1245
-
1246
- if macaron:
1247
- default_block = ('f',) + default_block
1248
-
1249
- # zero init
1250
-
1251
- if zero_init_branch_output:
1252
- attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
1253
- ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
1254
-
1255
- # calculate layer block order
1256
-
1257
- if exists(custom_layers):
1258
- layer_types = custom_layers
1259
- elif exists(par_ratio):
1260
- par_depth = depth * len(default_block)
1261
- assert 1 < par_ratio <= par_depth, 'par ratio out of range'
1262
- default_block = tuple(filter(not_equals('f'), default_block))
1263
- par_attn = par_depth // par_ratio
1264
- depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
1265
- par_width = (depth_cut + depth_cut // par_attn) // par_attn
1266
- assert len(default_block) <= par_width, 'default block is too large for par_ratio'
1267
- par_block = default_block + ('f',) * (par_width - len(default_block))
1268
- par_head = par_block * par_attn
1269
- layer_types = par_head + ('f',) * (par_depth - len(par_head))
1270
- elif exists(sandwich_coef):
1271
- assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
1272
- layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
1273
- else:
1274
- layer_types = default_block * depth
1275
-
1276
- self.layer_types = layer_types
1277
- self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
1278
-
1279
- # stochastic depth
1280
-
1281
- self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
1282
-
1283
- # structured dropout for cross attending
1284
-
1285
- self.cross_attn_tokens_dropout = cross_attn_tokens_dropout
1286
-
1287
- # calculate token shifting
1288
-
1289
- shift_tokens = cast_tuple(shift_tokens, len(layer_types))
1290
-
1291
- # iterate and construct layers
1292
-
1293
- for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
1294
- is_last_layer = ind == (len(self.layer_types) - 1)
1295
-
1296
- if layer_type == 'a':
1297
- layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
1298
- elif layer_type == 'c':
1299
- layer = Attention(dim, heads = heads, **attn_kwargs)
1300
- elif layer_type == 'f':
1301
- layer = FeedForward(dim, **ff_kwargs)
1302
- layer = layer if not macaron else Scale(0.5, layer)
1303
- else:
1304
- raise Exception(f'invalid layer type {layer_type}')
1305
-
1306
- if layer_shift_tokens > 0:
1307
- shift_range_upper = layer_shift_tokens + 1
1308
- shift_range_lower = -layer_shift_tokens if not causal else 0
1309
- layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer)
1310
-
1311
- residual_fn = GRUGating if gate_residual else Residual
1312
- residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
1313
-
1314
- pre_branch_norm = norm_fn() if pre_norm else None
1315
- post_branch_norm = norm_fn() if sandwich_norm else None
1316
- post_main_norm = norm_fn() if (resi_dual or not pre_norm) and not is_last_layer else None
1317
-
1318
- norms = nn.ModuleList([
1319
- pre_branch_norm,
1320
- post_branch_norm,
1321
- post_main_norm
1322
- ])
1323
-
1324
- self.layers.append(nn.ModuleList([
1325
- norms,
1326
- layer,
1327
- residual
1328
- ]))
1329
-
1330
- if deepnorm:
1331
- init_gain = (8 * depth) ** -0.25
1332
- deepnorm_init(self, init_gain)
1333
-
1334
- def forward(
1335
- self,
1336
- x,
1337
- context = None,
1338
- mask = None,
1339
- context_mask = None,
1340
- attn_mask = None,
1341
- self_attn_context_mask = None,
1342
- mems = None,
1343
- return_hiddens = False
1344
- ):
1345
- assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
1346
-
1347
- hiddens = []
1348
- intermediates = []
1349
- prev_attn = None
1350
- prev_cross_attn = None
1351
-
1352
- mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
1353
-
1354
- rotary_pos_emb = None
1355
- if exists(self.rotary_pos_emb):
1356
- max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
1357
- rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)
1358
-
1359
- outer_residual = x
1360
-
1361
- for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)):
1362
- is_last = ind == (len(self.layers) - 1)
1363
-
1364
- if self.training and layer_dropout > 0. and random() < layer_dropout:
1365
- continue
1366
-
1367
- if layer_type == 'a':
1368
- if return_hiddens:
1369
- hiddens.append(x)
1370
- layer_mem = mems.pop(0) if mems else None
1371
-
1372
- if layer_type == 'c':
1373
- if self.training and self.cross_attn_tokens_dropout > 0.:
1374
- context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)
1375
-
1376
- inner_residual = x
1377
-
1378
- pre_norm, post_branch_norm, post_main_norm = norm
1379
-
1380
- if exists(pre_norm) and not self.resi_dual:
1381
- x = pre_norm(x)
1382
-
1383
- if layer_type == 'a':
1384
- out, inter = block(x, mask = mask, context_mask = self_attn_context_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
1385
- elif layer_type == 'c':
1386
- out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
1387
- elif layer_type == 'f':
1388
- out = block(x)
1389
-
1390
- if self.resi_dual:
1391
- outer_residual = residual_fn(out, outer_residual)
1392
-
1393
- if exists(post_branch_norm):
1394
- out = post_branch_norm(out)
1395
-
1396
- x = residual_fn(out, inner_residual)
1397
-
1398
- if layer_type in ('a', 'c') and return_hiddens:
1399
- intermediates.append(inter)
1400
-
1401
- if layer_type == 'a' and self.residual_attn:
1402
- prev_attn = inter.pre_softmax_attn
1403
- elif layer_type == 'c' and self.cross_residual_attn:
1404
- prev_cross_attn = inter.pre_softmax_attn
1405
-
1406
- if exists(post_main_norm):
1407
- x = post_main_norm(x)
1408
-
1409
- if self.resi_dual:
1410
- x = x + pre_norm(outer_residual)
1411
-
1412
- if return_hiddens:
1413
- intermediates = LayerIntermediates(
1414
- hiddens = hiddens,
1415
- attn_intermediates = intermediates
1416
- )
1417
-
1418
- return x, intermediates
1419
-
1420
- return x
1421
-
1422
- class Encoder(AttentionLayers):
1423
- def __init__(self, **kwargs):
1424
- assert 'causal' not in kwargs, 'cannot set causality on encoder'
1425
- super().__init__(causal = False, **kwargs)
1426
-
1427
- class Decoder(AttentionLayers):
1428
- def __init__(self, **kwargs):
1429
- assert 'causal' not in kwargs, 'cannot set causality on decoder'
1430
- super().__init__(causal = True, **kwargs)
1431
-
1432
- class CrossAttender(AttentionLayers):
1433
- def __init__(self, **kwargs):
1434
- super().__init__(cross_attend = True, only_cross = True, **kwargs)
1435
-
1436
- class ViTransformerWrapper(nn.Module):
1437
- def __init__(
1438
- self,
1439
- *,
1440
- image_size,
1441
- patch_size,
1442
- attn_layers,
1443
- channels = 3,
1444
- num_classes = None,
1445
- dropout = 0.,
1446
- post_emb_norm = False,
1447
- emb_dropout = 0.
1448
- ):
1449
- super().__init__()
1450
- assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder'
1451
- assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
1452
- dim = attn_layers.dim
1453
- num_patches = (image_size // patch_size) ** 2
1454
- patch_dim = channels * patch_size ** 2
1455
-
1456
- self.patch_size = patch_size
1457
-
1458
- self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
1459
-
1460
- self.patch_to_embedding = nn.Sequential(
1461
- nn.LayerNorm(patch_dim),
1462
- nn.Linear(patch_dim, dim),
1463
- nn.LayerNorm(dim)
1464
- )
1465
-
1466
- self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
1467
- self.dropout = nn.Dropout(emb_dropout)
1468
-
1469
- self.attn_layers = attn_layers
1470
- self.norm = nn.LayerNorm(dim)
1471
- self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity()
1472
-
1473
- def forward(
1474
- self,
1475
- img,
1476
- return_embeddings = False
1477
- ):
1478
- p = self.patch_size
1479
-
1480
- x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
1481
- x = self.patch_to_embedding(x)
1482
- n = x.shape[1]
1483
-
1484
- x = x + self.pos_embedding[:, :n]
1485
-
1486
- x = self.post_emb_norm(x)
1487
- x = self.dropout(x)
1488
-
1489
- x = self.attn_layers(x)
1490
- x = self.norm(x)
1491
-
1492
- if not exists(self.mlp_head) or return_embeddings:
1493
- return x
1494
-
1495
- x = x.mean(dim = -2)
1496
- return self.mlp_head(x)
1497
-
1498
- class TransformerWrapper(nn.Module):
1499
- def __init__(
1500
- self,
1501
- *,
1502
- num_tokens,
1503
- max_seq_len,
1504
- attn_layers,
1505
- emb_dim = None,
1506
- max_mem_len = 0.,
1507
- shift_mem_down = 0,
1508
- emb_dropout = 0.,
1509
- post_emb_norm = False,
1510
- num_memory_tokens = None,
1511
- tie_embedding = False,
1512
- logits_dim = None,
1513
- use_abs_pos_emb = True,
1514
- scaled_sinu_pos_emb = False,
1515
- l2norm_embed = False,
1516
- emb_frac_gradient = 1. # GLM-130B and Cogview successfully used this, set at 0.1
1517
- ):
1518
- super().__init__()
1519
- assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1520
-
1521
- dim = attn_layers.dim
1522
- emb_dim = default(emb_dim, dim)
1523
- self.emb_dim = emb_dim
1524
- self.num_tokens = num_tokens
1525
- self.token_pad = num_tokens
1526
-
1527
- self.max_seq_len = max_seq_len
1528
- self.max_mem_len = max_mem_len
1529
- self.shift_mem_down = shift_mem_down
1530
-
1531
- self.l2norm_embed = l2norm_embed
1532
- self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
1533
-
1534
- if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
1535
- self.pos_emb = always(0)
1536
- elif scaled_sinu_pos_emb:
1537
- self.pos_emb = ScaledSinusoidalEmbedding(emb_dim)
1538
- else:
1539
- self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed)
1540
-
1541
- self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290
1542
-
1543
- self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity()
1544
- self.emb_dropout = nn.Dropout(emb_dropout)
1545
-
1546
- self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
1547
- self.attn_layers = attn_layers
1548
- self.norm = nn.LayerNorm(dim)
1549
-
1550
- self.init_()
1551
-
1552
- logits_dim = default(logits_dim, num_tokens)
1553
- self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
1554
-
1555
- # memory tokens (like [cls]) from Memory Transformers paper
1556
- num_memory_tokens = default(num_memory_tokens, 0)
1557
- self.num_memory_tokens = num_memory_tokens
1558
- if num_memory_tokens > 0:
1559
- self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
1560
-
1561
- def init_(self):
1562
- if self.l2norm_embed:
1563
- nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
1564
- if not isinstance(self.pos_emb, always):
1565
- nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
1566
- return
1567
-
1568
- nn.init.kaiming_normal_(self.token_emb.emb.weight)
1569
-
1570
- def forward(
1571
- self,
1572
- x,
1573
- return_embeddings = False,
1574
- return_logits_and_embeddings = False,
1575
- return_intermediates = False,
1576
- mask = None,
1577
- return_mems = False,
1578
- return_attn = False,
1579
- mems = None,
1580
- pos = None,
1581
- prepend_embeds = None,
1582
- sum_embeds = None,
1583
- **kwargs
1584
- ):
1585
- b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient
1586
- return_hiddens = return_mems | return_attn
1587
-
1588
- # absolute positional embedding
1589
-
1590
- external_pos_emb = exists(pos) and pos.dtype != torch.long
1591
- pos_emb = self.pos_emb(x, pos = pos) if not external_pos_emb else pos
1592
- x = self.token_emb(x) + pos_emb
1593
-
1594
- # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training
1595
-
1596
- if exists(sum_embeds):
1597
- x = x + sum_embeds
1598
-
1599
- # post embedding norm, purportedly leads to greater stabilization
1600
-
1601
- x = self.post_emb_norm(x)
1602
-
1603
- # whether to append embeds, as in PaLI, for image embeddings
1604
-
1605
- if exists(prepend_embeds):
1606
- prepend_seq, prepend_dim = prepend_embeds.shape[1:]
1607
- assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'
1608
-
1609
- x = torch.cat((prepend_embeds, x), dim = -2)
1610
-
1611
- # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model
1612
-
1613
- if emb_frac_gradient < 1:
1614
- assert emb_frac_gradient > 0
1615
- x = x * emb_frac_gradient + x.detach() * (1 - emb_frac_gradient)
1616
-
1617
- # embedding dropout
1618
-
1619
- x = self.emb_dropout(x)
1620
-
1621
- x = self.project_emb(x)
1622
-
1623
- if num_mem > 0:
1624
- mem = repeat(self.memory_tokens, 'n d -> b n d', b = b)
1625
- x = torch.cat((mem, x), dim = 1)
1626
-
1627
- # auto-handle masking after appending memory tokens
1628
- if exists(mask):
1629
- mask = pad_at_dim(mask, (num_mem, 0), dim = -1, value = True)
1630
-
1631
- if self.shift_mem_down and exists(mems):
1632
- mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:]
1633
- mems = [*mems_r, *mems_l]
1634
-
1635
- if return_hiddens:
1636
- x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
1637
- else:
1638
- x = self.attn_layers(x, mask = mask, mems = mems, **kwargs)
1639
-
1640
- x = self.norm(x)
1641
-
1642
- mem, x = x[:, :num_mem], x[:, num_mem:]
1643
-
1644
- if return_logits_and_embeddings:
1645
- out = (self.to_logits(x), x)
1646
- elif return_embeddings:
1647
- out = x
1648
- else:
1649
- out = self.to_logits(x)
1650
-
1651
- if return_intermediates:
1652
- return out, intermediates
1653
-
1654
- if return_mems:
1655
- hiddens = intermediates.hiddens
1656
- new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens
1657
- new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
1658
- return out, new_mems
1659
-
1660
- if return_attn:
1661
- attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1662
- return out, attn_maps
1663
-
1664
- return out
1665
-
1666
- class ContinuousTransformerWrapper(nn.Module):
1667
- def __init__(
1668
- self,
1669
- *,
1670
- max_seq_len,
1671
- attn_layers,
1672
- dim_in = None,
1673
- dim_out = None,
1674
- emb_dim = None,
1675
- post_emb_norm = False,
1676
- emb_dropout = 0.,
1677
- use_abs_pos_emb = True,
1678
- scaled_sinu_pos_emb = False
1679
- ):
1680
- super().__init__()
1681
- assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
1682
-
1683
- dim = attn_layers.dim
1684
-
1685
- self.max_seq_len = max_seq_len
1686
-
1687
- if not (use_abs_pos_emb and not attn_layers.has_pos_emb):
1688
- self.pos_emb = always(0)
1689
- elif scaled_sinu_pos_emb:
1690
- self.pos_emb = ScaledSinusoidalEmbedding(dim)
1691
- else:
1692
- self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len)
1693
-
1694
- self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity()
1695
- self.emb_dropout = nn.Dropout(emb_dropout)
1696
-
1697
- self.project_in = nn.Linear(dim_in, dim) if exists(dim_in) else nn.Identity()
1698
-
1699
- self.attn_layers = attn_layers
1700
- self.norm = nn.LayerNorm(dim)
1701
-
1702
- self.project_out = nn.Linear(dim, dim_out) if exists(dim_out) else nn.Identity()
1703
-
1704
- def forward(
1705
- self,
1706
- x,
1707
- return_embeddings = False,
1708
- return_intermediates = False,
1709
- mask = None,
1710
- return_attn = False,
1711
- mems = None,
1712
- pos = None,
1713
- prepend_embeds = None,
1714
- **kwargs
1715
- ):
1716
- x = self.project_in(x)
1717
- x = x + self.pos_emb(x, pos = pos)
1718
-
1719
- x = self.post_emb_norm(x)
1720
-
1721
- # whether to append embeds, as in PaLI, for image embeddings
1722
-
1723
- if exists(prepend_embeds):
1724
- _, prepend_dim = prepend_embeds.shape[1:]
1725
- assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as model dimensions'
1726
-
1727
- x = torch.cat((prepend_embeds, x), dim = -2)
1728
-
1729
- x = self.emb_dropout(x)
1730
-
1731
- x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs)
1732
- x = self.norm(x)
1733
-
1734
- out = self.project_out(x) if not return_embeddings else x
1735
-
1736
- if return_intermediates:
1737
- return out, intermediates
1738
-
1739
- if return_attn:
1740
- attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
1741
- return out, attn_maps
1742
-
1743
- return out
1744
-
1745
- class XTransformer(nn.Module):
1746
- def __init__(
1747
- self,
1748
- *,
1749
- dim,
1750
- tie_token_emb = False,
1751
- ignore_index = -100,
1752
- pad_value = 0,
1753
- deepnorm = False,
1754
- cross_attn_tokens_dropout = 0.,
1755
- **kwargs
1756
- ):
1757
- super().__init__()
1758
- enc_kwargs, kwargs = groupby_prefix_and_trim('enc_', kwargs)
1759
- dec_kwargs, kwargs = groupby_prefix_and_trim('dec_', kwargs)
1760
-
1761
- assert 'dim' not in enc_kwargs and 'dim' not in dec_kwargs, 'dimension of either encoder or decoder must be set with `dim` keyword'
1762
- enc_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], enc_kwargs)
1763
- enc_transformer_kwargs['emb_dropout'] = enc_kwargs.pop('emb_dropout', 0)
1764
- enc_transformer_kwargs['num_memory_tokens'] = enc_kwargs.pop('num_memory_tokens', None)
1765
- enc_transformer_kwargs['scaled_sinu_pos_emb'] = enc_kwargs.pop('scaled_sinu_pos_emb', False)
1766
- enc_transformer_kwargs['use_abs_pos_emb'] = enc_kwargs.pop('use_abs_pos_emb', True)
1767
-
1768
- dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
1769
- dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)
1770
- dec_transformer_kwargs['scaled_sinu_pos_emb'] = dec_kwargs.pop('scaled_sinu_pos_emb', False)
1771
- dec_transformer_kwargs['use_abs_pos_emb'] = dec_kwargs.pop('use_abs_pos_emb', True)
1772
-
1773
- self.cross_attn_tokens_dropout = cross_attn_tokens_dropout # how many tokens from the encoder to dropout when cross attending from decoder - seen in a couple papers, including Perceiver AR - this will also be very effective regularization when cross attending to very long memories
1774
-
1775
- if deepnorm:
1776
- enc_kwargs['scale_residual'] = True
1777
- dec_kwargs['scale_residual'] = True
1778
-
1779
- enc_depth = enc_kwargs['depth']
1780
- dec_depth = dec_kwargs['depth']
1781
-
1782
- enc_kwargs['scale_residual_constant'] = 0.81 * ((enc_depth ** 4) * dec_depth) ** .0625
1783
- dec_kwargs['scale_residual_constant'] = (3 * dec_depth) ** 0.25
1784
-
1785
- self.encoder = TransformerWrapper(
1786
- **enc_transformer_kwargs,
1787
- attn_layers = Encoder(dim = dim, **enc_kwargs)
1788
- )
1789
-
1790
- self.decoder = TransformerWrapper(
1791
- **dec_transformer_kwargs,
1792
- attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
1793
- )
1794
-
1795
- if deepnorm:
1796
- deepnorm_init(self.encoder, 0.87 * ((enc_depth ** 4) * dec_depth) ** -0.0625)
1797
- deepnorm_init(self.decoder, (12 * dec_depth) ** -0.25)
1798
-
1799
- if tie_token_emb:
1800
- self.decoder.token_emb = self.encoder.token_emb
1801
-
1802
- self.decoder = AutoregressiveWrapper(self.decoder, ignore_index=ignore_index, pad_value=pad_value)
1803
-
1804
- @torch.no_grad()
1805
- def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
1806
- encodings = self.encoder(seq_in, mask = mask, attn_mask = attn_mask, return_embeddings = True)
1807
- return self.decoder.generate(seq_out_start, seq_len, context = encodings, context_mask = mask, **kwargs)
1808
-
1809
- def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
1810
-
1811
- if exists(src_prepend_embeds) and exists(mask):
1812
- mask = pad_at_dim(mask, (src_prepend_embeds.shape[-2], 0), dim = -1, value = True)
1813
-
1814
- enc = self.encoder(src, mask = mask, attn_mask = attn_mask, prepend_embeds = src_prepend_embeds, return_embeddings = True)
1815
-
1816
- if self.training and self.cross_attn_tokens_dropout > 0:
1817
- enc, mask = dropout_seq(enc, mask, self.cross_attn_tokens_dropout)
1818
-
1819
- out = self.decoder(tgt, context = enc, context_mask = mask)
1820
- return out
1821
-
1822
- #===================================================================================================================
1823
-
1824
- def exists(val):
1825
- return val is not None
1826
-
1827
- def eval_decorator(fn):
1828
- def inner(self, *args, **kwargs):
1829
- was_training = self.training
1830
- self.eval()
1831
- out = fn(self, *args, **kwargs)
1832
- self.train(was_training)
1833
- return out
1834
- return inner
1835
-
1836
- # nucleus
1837
-
1838
- def top_p(logits, thres = 0.9):
1839
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
1840
- cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
1841
-
1842
- sorted_indices_to_remove = cum_probs > (1 - thres)
1843
- sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
1844
- sorted_indices_to_remove[:, 0] = 0
1845
-
1846
- sorted_logits[sorted_indices_to_remove] = float('-inf')
1847
- return sorted_logits.scatter(1, sorted_indices, sorted_logits)
1848
-
1849
- # topk
1850
-
1851
- def top_k(logits, thres = 0.9):
1852
- k = ceil((1 - thres) * logits.shape[-1])
1853
- val, ind = torch.topk(logits, k)
1854
- probs = torch.full_like(logits, float('-inf'))
1855
- probs.scatter_(1, ind, val)
1856
- return probs
1857
-
1858
- # top_a
1859
-
1860
- def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
1861
- probs = F.softmax(logits, dim=-1)
1862
- limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
1863
- logits[probs < limit] = float('-inf')
1864
- logits[probs >= limit] = 1
1865
- return logits
1866
-
1867
- # autoregressive wrapper class
1868
-
1869
- class AutoregressiveWrapper(nn.Module):
1870
- def __init__(
1871
- self,
1872
- net,
1873
- ignore_index = -100,
1874
- pad_value = 0,
1875
- mask_prob = 0.
1876
- ):
1877
- super().__init__()
1878
- self.pad_value = pad_value
1879
- self.ignore_index = ignore_index
1880
-
1881
- self.net = net
1882
- self.max_seq_len = net.max_seq_len
1883
-
1884
- # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
1885
- assert mask_prob < 1.
1886
- self.mask_prob = mask_prob
1887
-
1888
- @torch.no_grad()
1889
- @eval_decorator
1890
- def generate(
1891
- self,
1892
- start_tokens,
1893
- seq_len,
1894
- eos_token = None,
1895
- temperature = 1.,
1896
- filter_logits_fn = top_k,
1897
- filter_thres = 0.9,
1898
- min_p_pow = 2.0,
1899
- min_p_ratio = 0.02,
1900
- verbose=True,
1901
- return_prime=False,
1902
- **kwargs
1903
- ):
1904
- device = start_tokens.device
1905
- num_dims = start_tokens.ndim
1906
-
1907
- start_tokens, ps = pack([start_tokens], '* n')
1908
-
1909
- b, t = start_tokens.shape
1910
-
1911
- out = start_tokens
1912
-
1913
- if verbose:
1914
- print("Generating sequence of max length:", seq_len)
1915
-
1916
- for s in range(seq_len):
1917
- x = out[:, -self.max_seq_len:]
1918
-
1919
- logits = self.net(x, **kwargs)[:, -1]
1920
-
1921
- if filter_logits_fn in {top_k, top_p}:
1922
- filtered_logits = filter_logits_fn(logits, thres = filter_thres)
1923
- probs = F.softmax(filtered_logits / temperature, dim=-1)
1924
-
1925
- elif filter_logits_fn is top_a:
1926
- filtered_logits = filter_logits_fn(logits, min_p_pow = min_p_pow, min_p_ratio= min_p_ratio)
1927
- probs = F.softmax(filtered_logits / temperature, dim=-1)
1928
-
1929
- sample = torch.multinomial(probs, 1)
1930
-
1931
- out = torch.cat((out, sample), dim=-1)
1932
-
1933
- if verbose:
1934
- if s % 32 == 0:
1935
- print(s, '/', seq_len)
1936
-
1937
- if exists(eos_token):
1938
- is_eos_tokens = (out == eos_token)
1939
-
1940
- if is_eos_tokens.any(dim = -1).all():
1941
- # mask out everything after the eos tokens
1942
- shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
1943
- mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
1944
- out = out.masked_fill(mask, self.pad_value)
1945
-
1946
- if verbose:
1947
- print('Model called the end of sequence at:', s, '/', seq_len)
1948
-
1949
- break
1950
-
1951
- if return_prime:
1952
- return out[:, :]
1953
-
1954
- else:
1955
- return out[:, t:]
1956
-
1957
- out, = unpack(out, ps, '* n')
1958
-
1959
- return out
1960
-
1961
- def compute_accuracy(self, logits, labels):
1962
- out = torch.argmax(logits, dim=-1)
1963
- out = out.flatten()
1964
- labels = labels.flatten()
1965
-
1966
- mask = (labels != 999999) # dummy pad value / supposed to be self.token_pad / will fix later
1967
- out = out[mask]
1968
- labels = labels[mask]
1969
-
1970
- num_right = (out == labels)
1971
- num_right = torch.sum(num_right).type(torch.float32)
1972
-
1973
- acc = num_right / len(labels)
1974
- return acc
1975
-
1976
- def forward(self, x, labels = None, **kwargs):
1977
- seq, ignore_index = x.shape[1], self.ignore_index
1978
-
1979
- inp, target = x[:, :-1], x[:, 1:]
1980
-
1981
- if self.mask_prob > 0.:
1982
- rand = torch.randn(inp.shape, device = x.device)
1983
- rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
1984
- num_mask = min(int(seq * self.mask_prob), seq - 1)
1985
- indices = rand.topk(num_mask, dim = -1).indices
1986
- mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
1987
- kwargs.update(self_attn_context_mask = mask)
1988
-
1989
- logits = self.net(inp, **kwargs)
1990
-
1991
- acc = self.compute_accuracy(logits, target)
1992
-
1993
- loss = F.cross_entropy(
1994
- rearrange(logits, 'b n c -> b c n'),
1995
- target,
1996
- ignore_index = ignore_index
1997
- )
1998
-
1999
- return loss, acc
2000
-
2001
- #===================================================================================================================