Alex Birch committed
Commit 07e555c
Parent: 1e53ac9

gradient checkpointing for multi-query attention

Files changed (1):
  attention.py (+64, -5)
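In short, the commit gives MultiQueryAttention the same treatment MultiheadAttention already had: when the module is training with its gradient_checkpointing flag enabled, the attention function is wrapped in torch.utils.checkpoint, so its activations are recomputed during the backward pass instead of being kept in memory. The snippet below is a minimal sketch of that pattern, not the repository's code; toy_attn and the tensor shapes are illustrative, and a plain torch-version comparison stands in for the is_torch_version helper the diff calls.

import torch
from packaging import version
from torch.utils.checkpoint import checkpoint

def toy_attn(query, key, value):
    # plain scaled dot-product attention over (batch, seq, dim) tensors
    scale = query.size(-1) ** -0.5
    weights = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return weights @ value

# mirror the diff's ckpt_kwargs: only pass use_reentrant=False on torch >= 1.11, as the change does
ckpt_kwargs = {'use_reentrant': False} if version.parse(torch.__version__) >= version.parse('1.11.0') else {}

query = torch.randn(2, 8, 16, requires_grad=True)
key = torch.randn(2, 8, 16)
value = torch.randn(2, 8, 16)

# activations inside toy_attn are freed after the forward pass and recomputed
# during backward, trading extra compute for lower peak memory
context = checkpoint(toy_attn, query, key, value, **ckpt_kwargs)
context.sum().backward()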
attention.py CHANGED
@@ -316,7 +316,7 @@ class MultiheadAttention(nn.Module, Attn):
                         False, # multiquery
                     )
                 return custom_forward
-            attn_out: AttnOutput = checkpoint(
+            attn_fn_out: AttnFnOutput = checkpoint(
                 create_custom_forward(self.attn_fn),
                 query,
                 key,
@@ -332,7 +332,7 @@ class MultiheadAttention(nn.Module, Attn):
                 **ckpt_kwargs,
             )
         else:
-            attn_out: AttnOutput = self.attn_fn(
+            attn_fn_out: AttnFnOutput = self.attn_fn(
                 query,
                 key,
                 value,
@@ -345,7 +345,7 @@ class MultiheadAttention(nn.Module, Attn):
                 training=self.training,
                 needs_weights=needs_weights,
             )
-        context, attn_weights = attn_out
+        context, attn_weights = attn_fn_out
         return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
 
 class MultiQueryAttention(nn.Module, Attn):
@@ -413,8 +413,67 @@ class MultiQueryAttention(nn.Module, Attn):
             past_key_value = PastKeyValue(key, value)
         if attn_bias is not None:
             attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
-        attn_fn_output: AttnFnOutput = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
-        context, attn_weights = attn_fn_output
+        if self.training and self.gradient_checkpointing:
+            ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
+            def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
+                def custom_forward(
+                    query: torch.Tensor,
+                    key: torch.Tensor,
+                    value: torch.Tensor,
+                    n_heads: int,
+                    softmax_scale: Optional[float],
+                    attn_bias: Optional[torch.Tensor],
+                    key_padding_mask: Optional[torch.ByteTensor],
+                    is_causal: bool,
+                    dropout_p: float,
+                    training: bool,
+                    needs_weights: bool,
+                ):
+                    return attn_fn(
+                        query,
+                        key,
+                        value,
+                        n_heads,
+                        softmax_scale,
+                        attn_bias,
+                        key_padding_mask,
+                        is_causal,
+                        dropout_p,
+                        training,
+                        needs_weights,
+                        True, # multiquery
+                    )
+                return custom_forward
+            attn_fn_out: AttnFnOutput = checkpoint(
+                create_custom_forward(self.attn_fn),
+                query,
+                key,
+                value,
+                self.n_heads,
+                self.softmax_scale,
+                attn_bias,
+                key_padding_mask,
+                is_causal,
+                self.attn_dropout_p,
+                self.training,
+                needs_weights,
+                **ckpt_kwargs,
+            )
+        else:
+            attn_fn_out: AttnFnOutput = self.attn_fn(
+                query,
+                key,
+                value,
+                self.n_heads,
+                softmax_scale=self.softmax_scale,
+                attn_bias=attn_bias,
+                key_padding_mask=key_padding_mask,
+                is_causal=is_causal,
+                dropout_p=self.attn_dropout_p,
+                training=self.training,
+                needs_weights=needs_weights,
+            )
+        context, attn_weights = attn_fn_out
         return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
 
 def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
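The new branch only takes effect when the module is in training mode and its gradient_checkpointing attribute is truthy; at eval time, or with the flag left off, the original un-checkpointed call runs unchanged, so inference never pays the recompute cost. A toy module sketching that gating (torch >= 1.11 assumed for use_reentrant; everything except the gradient_checkpointing flag name and the checkpoint call is illustrative):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

class ToyCheckpointedAttn(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.gradient_checkpointing = False  # the same flag the diff tests
        self.out_proj = nn.Linear(dim, dim)

    def _attn(self, q, k, v):
        scale = q.size(-1) ** -0.5
        weights = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        return weights @ v

    def forward(self, q, k, v):
        if self.training and self.gradient_checkpointing:
            # recompute _attn's activations in backward, as the checkpointed path does
            context = checkpoint(self._attn, q, k, v, use_reentrant=False)
        else:
            context = self._attn(q, k, v)
        return self.out_proj(context)

module = ToyCheckpointedAttn(16)
module.gradient_checkpointing = True
module.train()
q = torch.randn(2, 8, 16, requires_grad=True)
module(q, q, q).sum().backward()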