Alex Birch committed
Commit 9f0a20b
Parent: 512b004

Add support for AutoModelForCausalLM#from_pretrained()'s device_map='auto'. Support gradient checkpointing, probably. Add lots of type hints so I could understand what's going on. Split long method signatures/calls across multiple lines (for easier comparison between checkpointed/non-checkpointed variants, and because these lines got even longer once type hints were added). Make MPTForCausalLM#forward accept additional kwargs, since PeftModelForCausalLM#forward tries to send it inputs_embeds=None, which it didn't like too much.

Files changed (4):
  1. attention.py +195 -18
  2. blocks.py +9 -4
  3. is_torch_version.py +56 -0
  4. modeling_mpt.py +68 -5
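
For orientation, here is a rough usage sketch of what this commit is aiming at: sharding the model with device_map='auto', toggling gradient checkpointing through the standard Hugging Face hooks, and wrapping the model with PEFT (whose PeftModelForCausalLM passes inputs_embeds=None through to forward). The repo id and LoRA hyperparameters below are illustrative assumptions, not part of the commit.

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# device_map='auto' shards the checkpoint across available devices; it relies on
# MPTPreTrainedModel._no_split_modules = ['MPTBlock'] added in this commit.
model = AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b',          # assumed repo id, for illustration only
    torch_dtype=torch.bfloat16,
    device_map='auto',
    trust_remote_code=True,
)

# supports_gradient_checkpointing / _set_gradient_checkpointing (also added here)
# should let the usual toggle reach MPTModel.gradient_checkpointing.
model.gradient_checkpointing_enable()

# PeftModelForCausalLM#forward sends inputs_embeds=None, which
# MPTForCausalLM#forward now tolerates via *args, **kwargs.
lora = LoraConfig(task_type='CAUSAL_LM', r=8, lora_alpha=16, target_modules=['Wqkv'])
model = get_peft_model(model, lora)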
attention.py CHANGED

@@ -1,13 +1,72 @@
 """Attention layers."""
 import math
 import warnings
- from typing import Optional
+ from typing import Optional, Dict, Any, NamedTuple, Protocol, Tuple, Union
 import torch
 import torch.nn as nn
 from einops import rearrange
 from packaging import version
 from torch import nn
+ from torch.utils.checkpoint import checkpoint
 from .norm import LPLayerNorm
+ from .is_torch_version import is_torch_version
+
+ class PastKeyValue(NamedTuple):
+     key: torch.Tensor
+     value: torch.Tensor
+
+ class AttnFnOutput(NamedTuple):
+     attns: torch.Tensor
+     attn_probs: Optional[torch.Tensor]
+
+ class AttnFn(Protocol):
+     def __call__(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         n_heads: int,
+         softmax_scale: Optional[float] = None,
+         attn_bias: Optional[torch.Tensor] = None,
+         key_padding_mask: Optional[torch.ByteTensor] = None,
+         is_causal = False,
+         dropout_p = 0.0,
+         training = False,
+         needs_weights = False,
+         multiquery = False,
+     ) -> AttnFnOutput: ...
+
+ class AttnFnCheckpointed(Protocol):
+     def __call__(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         n_heads: int,
+         softmax_scale: Optional[float],
+         attn_bias: Optional[torch.Tensor],
+         key_padding_mask: Optional[torch.ByteTensor],
+         is_causal: bool,
+         dropout_p: float,
+         training: bool,
+         needs_weights: bool,
+     ) -> AttnFnOutput: ...
+
+ class AttnOutput(NamedTuple):
+     projected_context: torch.Tensor
+     attn_weights: Optional[torch.Tensor]
+     past_key_value: Union[PastKeyValue, Tuple, None]
+
+ class Attn(Protocol):
+     def __call__(
+         self,
+         x: torch.Tensor,
+         past_key_value: Union[PastKeyValue, Tuple, None] = None,
+         attn_bias: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.ByteTensor] = None,
+         is_causal = True,
+         needs_weights = False,
+     ) -> AttnOutput: ...
 
 def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
     if original_is_causal and num_query_tokens != num_key_tokens:
@@ -17,7 +76,20 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
             return False
     return original_is_causal
 
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+ def scaled_multihead_dot_product_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     n_heads: int,
+     softmax_scale: Optional[float] = None,
+     attn_bias: Optional[torch.Tensor] = None,
+     key_padding_mask: Optional[torch.ByteTensor] = None,
+     is_causal = False,
+     dropout_p = 0.0,
+     training = False,
+     needs_weights = False,
+     multiquery = False,
+ ) -> AttnFnOutput:
     q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
     k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
     v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
@@ -33,7 +105,7 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
         attn_weight = attn_weight + attn_bias
     if key_padding_mask is not None:
         if attn_bias is not None:
-             warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
+             warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
         attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
     if is_causal:
         s = max(s_q, s_k)
@@ -50,7 +122,7 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
     out = rearrange(out, 'b h s d -> b s (h d)')
     if needs_weights:
         return (out, attn_weight)
-     return (out, None)
+     return AttnFnOutput(out, None)
 
 def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
     for tensor in tensors:
@@ -59,7 +131,20 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
         if not tensor.is_cuda:
             raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
 
- def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+ def flash_attn_fn(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     n_heads: int,
+     softmax_scale: Optional[float] = None,
+     attn_bias: Optional[torch.Tensor] = None,
+     key_padding_mask: Optional[torch.ByteTensor] = None,
+     is_causal = False,
+     dropout_p = 0.0,
+     training = False,
+     needs_weights = False,
+     multiquery = False,
+ ) -> AttnFnOutput:
     try:
         from flash_attn import bert_padding, flash_attn_interface
     except:
@@ -84,9 +169,22 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
     output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
     output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
-     return (output, None)
+     return AttnFnOutput(output, None)
 
- def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+ def triton_flash_attn_fn(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     n_heads: int,
+     softmax_scale: Optional[float] = None,
+     attn_bias: Optional[torch.Tensor] = None,
+     key_padding_mask: Optional[torch.ByteTensor] = None,
+     is_causal = False,
+     dropout_p = 0.0,
+     training = False,
+     needs_weights = False,
+     multiquery = False,
+ ) -> AttnFnOutput:
     try:
         from .flash_attn_triton import flash_attn_func
     except:
@@ -119,14 +217,16 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
     attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
     output = attn_output.view(*attn_output.shape[:2], -1)
-     return (output, None)
+     return AttnFnOutput(output, None)
 
- class MultiheadAttention(nn.Module):
+ class MultiheadAttention(nn.Module, Attn):
     """Multi-head self attention.
 
     Using torch or triton attention implemetation enables user to also use
     additive bias.
     """
+     gradient_checkpointing = False
+     attn_fn: AttnFn
 
     def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
         super().__init__()
@@ -160,7 +260,15 @@ class MultiheadAttention(nn.Module):
         self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
         self.out_proj._is_residual = True
 
-     def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
+     def forward(
+         self,
+         x: torch.Tensor,
+         past_key_value: Union[PastKeyValue, Tuple, None] = None,
+         attn_bias: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.ByteTensor] = None,
+         is_causal = True,
+         needs_weights = False,
+     ) -> AttnOutput:
         qkv = self.Wqkv(x)
         if self.clip_qkv:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
@@ -174,13 +282,73 @@ class MultiheadAttention(nn.Module):
         if len(past_key_value) != 0:
             key = torch.cat([past_key_value[0], key], dim=1)
             value = torch.cat([past_key_value[1], value], dim=1)
-         past_key_value = (key, value)
+         past_key_value = PastKeyValue(key, value)
         if attn_bias is not None:
             attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
-         (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
-         return (self.out_proj(context), attn_weights, past_key_value)
+         if self.training and self.gradient_checkpointing:
+             ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
+             def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
+                 def custom_forward(
+                     query: torch.Tensor,
+                     key: torch.Tensor,
+                     value: torch.Tensor,
+                     n_heads: int,
+                     softmax_scale: Optional[float],
+                     attn_bias: Optional[torch.Tensor],
+                     key_padding_mask: Optional[torch.ByteTensor],
+                     is_causal: bool,
+                     dropout_p: float,
+                     training: bool,
+                     needs_weights: bool,
+                 ):
+                     return attn_fn(
+                         query,
+                         key,
+                         value,
+                         n_heads,
+                         softmax_scale,
+                         attn_bias,
+                         key_padding_mask,
+                         is_causal,
+                         dropout_p,
+                         training,
+                         needs_weights,
+                         False, # multiquery
+                     )
+                 return custom_forward
+             attn_out: AttnOutput = checkpoint(
+                 create_custom_forward(self.attn_fn),
+                 query,
+                 key,
+                 value,
+                 self.n_heads,
+                 self.softmax_scale,
+                 attn_bias,
+                 key_padding_mask,
+                 is_causal,
+                 self.attn_dropout_p,
+                 self.training,
+                 needs_weights,
+                 **ckpt_kwargs,
+             )
+         else:
+             attn_out: AttnOutput = self.attn_fn(
+                 query,
+                 key,
+                 value,
+                 self.n_heads,
+                 softmax_scale=self.softmax_scale,
+                 attn_bias=attn_bias,
+                 key_padding_mask=key_padding_mask,
+                 is_causal=is_causal,
+                 dropout_p=self.attn_dropout_p,
+                 training=self.training,
+                 needs_weights=needs_weights,
+             )
+         context, attn_weights = attn_out
+         return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
 
- class MultiQueryAttention(nn.Module):
+ class MultiQueryAttention(nn.Module, Attn):
     """Multi-Query self attention.
 
     Using torch or triton attention implemetation enables user to also use
@@ -220,7 +388,15 @@ class MultiQueryAttention(nn.Module):
         self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
         self.out_proj._is_residual = True
 
-     def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
+     def forward(
+         self,
+         x: torch.Tensor,
+         past_key_value: Union[PastKeyValue, Tuple, None] = None,
+         attn_bias: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.ByteTensor] = None,
+         is_causal = True,
+         needs_weights = False,
+     ) -> AttnOutput:
         qkv = self.Wqkv(x)
         if self.clip_qkv:
             qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
@@ -234,11 +410,12 @@ class MultiQueryAttention(nn.Module):
         if len(past_key_value) != 0:
             key = torch.cat([past_key_value[0], key], dim=1)
             value = torch.cat([past_key_value[1], value], dim=1)
-         past_key_value = (key, value)
+         past_key_value = PastKeyValue(key, value)
         if attn_bias is not None:
             attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
-         (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
-         return (self.out_proj(context), attn_weights, past_key_value)
+         attn_fn_output: AttnFnOutput = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
+         context, attn_weights = attn_fn_output
+         return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
 
 def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
     if attn_impl == 'flash':
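
The checkpointed branch added to MultiheadAttention.forward above follows the common wrap-in-a-closure pattern around torch.utils.checkpoint.checkpoint, passing use_reentrant=False only when the installed torch is new enough (>= 1.11). A stripped-down sketch of just that pattern, independent of the MPT classes (heavy_fn is a made-up stand-in):

import torch
from packaging import version
from torch.utils.checkpoint import checkpoint

def heavy_fn(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # stand-in for an attention kernel whose intermediate activations we don't want to keep
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ k

def create_custom_forward(fn):
    # closure so checkpoint() only ever sees plain tensor arguments
    def custom_forward(*inputs):
        return fn(*inputs)
    return custom_forward

q = torch.randn(2, 8, 16, requires_grad=True)
k = torch.randn(2, 8, 16, requires_grad=True)

# non-reentrant checkpointing is available from torch 1.11 onward
ckpt_kwargs = {'use_reentrant': False} if version.parse(torch.__version__) >= version.parse('1.11.0') else {}
out = checkpoint(create_custom_forward(heavy_fn), q, k, **ckpt_kwargs)
out.sum().backward()  # heavy_fn is re-run here to rebuild the discarded activations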
blocks.py CHANGED

@@ -1,10 +1,14 @@
 """GPT Blocks used for the GPT Model."""
- from typing import Dict, Optional, Tuple
+ from typing import Dict, Optional, Tuple, NamedTuple, Union
 import torch
 import torch.nn as nn
- from .attention import ATTN_CLASS_REGISTRY
+ from .attention import ATTN_CLASS_REGISTRY, Attn, PastKeyValue
 from .norm import NORM_CLASS_REGISTRY
 
+ class MPTBlockOutput(NamedTuple):
+     hidden_states: torch.Tensor
+     past_key_value: Union[PastKeyValue, Tuple, None]
+
 class MPTMLP(nn.Module):
 
     def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
@@ -18,6 +22,7 @@ class MPTMLP(nn.Module):
         return self.down_proj(self.act(self.up_proj(x)))
 
 class MPTBlock(nn.Module):
+     attn: Attn
 
     def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
         del kwargs
@@ -31,11 +36,11 @@ class MPTBlock(nn.Module):
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
         self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
 
-     def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
+     def forward(self, x: torch.Tensor, past_key_value: Union[PastKeyValue, Tuple, None] = None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> MPTBlockOutput:
         a = self.norm_1(x)
         (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
         x = x + self.resid_attn_dropout(b)
         m = self.norm_2(x)
         n = self.ffn(m)
         x = x + self.resid_ffn_dropout(n)
-         return (x, past_key_value)
+         return MPTBlockOutput(x, past_key_value)
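
One detail worth noting in blocks.py: MPTBlock now returns a NamedTuple instead of a bare tuple. Because NamedTuple instances are tuples, existing call sites that unpack positionally keep working while new code can use field names. A tiny self-contained illustration (mirroring the names above):

from typing import NamedTuple, Optional, Tuple
import torch

class MPTBlockOutput(NamedTuple):
    hidden_states: torch.Tensor
    past_key_value: Optional[Tuple]

out = MPTBlockOutput(torch.zeros(1, 4, 8), None)
x, past_key_value = out           # positional unpacking, as in MPTModel.forward
assert out.hidden_states is x     # named access works too
assert isinstance(out, tuple)     # NamedTuple subclasses tuple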
is_torch_version.py ADDED

@@ -0,0 +1,56 @@
+ import sys
+ import logging
+ import operator as op
+ from packaging import version
+ from packaging.version import Version, parse
+ from typing import Union
+ import importlib.util
+
+ # The package importlib_metadata is in a different place, depending on the python version.
+ if sys.version_info < (3, 8):
+     import importlib_metadata
+ else:
+     import importlib.metadata as importlib_metadata
+
+ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
+
+ logger = logging.getLogger(__name__)
+
+ _torch_available = importlib.util.find_spec("torch") is not None
+ if _torch_available:
+     try:
+         _torch_version = importlib_metadata.version("torch")
+         logger.info(f"PyTorch version {_torch_version} available.")
+     except importlib_metadata.PackageNotFoundError:
+         _torch_available = False
+
+ # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
+ def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
+     """
+     Compares a library version to some requirement using a given operation.
+     Args:
+         library_or_version (`str` or `packaging.version.Version`):
+             A library name or a version to check.
+         operation (`str`):
+             A string representation of an operator, such as `">"` or `"<="`.
+         requirement_version (`str`):
+             The version to compare the library version against
+     """
+     if operation not in STR_OPERATION_TO_FUNC.keys():
+         raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
+     operation = STR_OPERATION_TO_FUNC[operation]
+     if isinstance(library_or_version, str):
+         library_or_version = parse(importlib_metadata.version(library_or_version))
+     return operation(library_or_version, parse(requirement_version))
+
+ # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
+ def is_torch_version(operation: str, version: str):
+     """
+     Compares the current PyTorch version to a given reference with an operation.
+     Args:
+         operation (`str`):
+             A string representation of an operator, such as `">"` or `"<="`
+         version (`str`):
+             A string version of PyTorch
+     """
+     return compare_versions(parse(_torch_version), operation, version)
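
A quick illustration of how the two helpers behave (flat imports here instead of the package-relative import the model code uses; the printed values obviously depend on the installed environment):

from packaging.version import parse
from is_torch_version import compare_versions, is_torch_version

print(is_torch_version('>=', '1.11.0'))                 # True on any reasonably recent torch
print(compare_versions('packaging', '>', '17.0'))       # look up an installed library by name
print(compare_versions(parse('1.13.1'), '<', '2.0.0'))  # or pass an explicit Version object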
modeling_mpt.py CHANGED

@@ -4,25 +4,45 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
 import math
 import warnings
- from typing import List, Optional, Tuple, Union
+ from typing import Any, List, Optional, Tuple, Union, Protocol, Dict
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+ from torch.utils.checkpoint import checkpoint
 from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from .attention import attn_bias_shape, build_attn_bias
- from .blocks import MPTBlock
+ from transformers.utils import logging
+ from .attention import attn_bias_shape, build_attn_bias, PastKeyValue
+ from .blocks import MPTBlock, MPTBlockOutput
 from .norm import NORM_CLASS_REGISTRY
 from .configuration_mpt import MPTConfig
 from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
 from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
 from .meta_init_context import init_empty_weights
 from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
+ from .is_torch_version import is_torch_version
+
 Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
+ logger = logging.get_logger(__name__)
+
+ class MPTBlockCheckpointedForward(Protocol):
+     def __call__(
+         x: torch.Tensor,
+         past_key_value: Union[PastKeyValue, Tuple, None],
+         attn_bias: Optional[torch.Tensor],
+         attention_mask: Optional[torch.ByteTensor],
+         is_causal: bool,
+     ) -> MPTBlockOutput: ...
+
 class MPTPreTrainedModel(PreTrainedModel):
     config_class = MPTConfig
     base_model_prefix = 'model'
+     _no_split_modules = ['MPTBlock']
+     supports_gradient_checkpointing = True
+     def _set_gradient_checkpointing(self, module: nn.Module, value=False) -> None:
+         if isinstance(module, MPTModel):
+             module.gradient_checkpointing = value
 
 class MPTModel(MPTPreTrainedModel):
 
@@ -64,6 +84,7 @@ class MPTModel(MPTPreTrainedModel):
         if self.config.init_config['verbose'] > 1:
             init_fn_name = self.config.init_config['name']
             warnings.warn(f'Using {init_fn_name} initialization.')
+         self.gradient_checkpointing = False
 
     def get_input_embeddings(self):
         return self.wte
@@ -130,6 +151,12 @@
     def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
         if attention_mask is not None:
             attention_mask = attention_mask.bool()
         if prefix_mask is not None:
@@ -180,7 +207,43 @@
                 assert all_hidden_states is not None
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = past_key_values[b_idx] if past_key_values is not None else None
-             (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
+             if self.gradient_checkpointing and self.training:
+                 ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
+                 def create_custom_forward(module: MPTBlock) -> MPTBlockCheckpointedForward:
+                     def custom_forward(
+                         x: torch.Tensor,
+                         past_key_value: Union[PastKeyValue, Tuple, None],
+                         attn_bias: Optional[torch.Tensor],
+                         attention_mask: Optional[torch.ByteTensor],
+                         is_causal: bool
+                     ):
+                         return module.forward(
+                             x,
+                             past_key_value,
+                             attn_bias,
+                             attention_mask,
+                             is_causal,
+                         )
+                     return custom_forward
+                 block_out: MPTBlockOutput = checkpoint(
+                     create_custom_forward(block),
+                     x,
+                     past_key_value,
+                     attn_bias,
+                     attention_mask,
+                     self.is_causal,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 block_out: MPTBlockOutput = block(
+                     x,
+                     past_key_value=past_key_value,
+                     attn_bias=attn_bias,
+                     attention_mask=attention_mask,
+                     is_causal=self.is_causal,
+                 )
+             x, past_key_value = block_out
+             del block_out
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
         x = self.norm_f(x)
@@ -231,7 +294,7 @@
     def get_decoder(self):
         return self.transformer
 
-     def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
+     def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, *args, **kwargs):
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
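
The trailing *args, **kwargs on MPTForCausalLM#forward is there purely so that wrappers which forward extra framework arguments (PeftModelForCausalLM in particular, which passes inputs_embeds=None) don't blow up. A minimal sketch of the failure mode being avoided, with made-up class names:

import torch
import torch.nn as nn

class StrictLM(nn.Module):
    def forward(self, input_ids: torch.LongTensor):
        return input_ids

class TolerantLM(nn.Module):
    def forward(self, input_ids: torch.LongTensor, *args, **kwargs):
        # extra keyword arguments such as inputs_embeds=None are accepted and ignored
        return input_ids

ids = torch.tensor([[1, 2, 3]])
TolerantLM()(ids, inputs_embeds=None)   # fine: swallowed by **kwargs
# StrictLM()(ids, inputs_embeds=None)   # TypeError: unexpected keyword argument 'inputs_embeds'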