NohTow committed
Commit e800526 · 1 Parent(s): 6088290

fix bert_padding

__pycache__/attention.cpython-311.pyc CHANGED
Binary files a/__pycache__/attention.cpython-311.pyc and b/__pycache__/attention.cpython-311.pyc differ
 
__pycache__/layers.cpython-311.pyc CHANGED
Binary files a/__pycache__/layers.cpython-311.pyc and b/__pycache__/layers.cpython-311.pyc differ
 
attention.py CHANGED
@@ -24,7 +24,7 @@ import sys
 import os
 # Add src folder root to path to allow us to use relative imports regardless of what directory the script is run from
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
-import bert_padding
+from .bert_padding import pad_input, unpad_input_only, index_first_axis
 from .configuration_bert import FlexBertConfig, maybe_add_padding
 from .normalization import get_norm_layer
 from .initialization import ModuleType, init_weights
@@ -161,7 +161,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
                 alibi_slopes=slopes,
             )
         else:
-            qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
+            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
             unpad_bs, *_ = qkv.shape
             qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attention_head_size)
             # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
@@ -174,7 +174,7 @@ class BertAlibiUnpadSelfAttention(nn.Module):
             attention_probs = self.dropout(attention_probs)
             attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3)  # b s h d
 
-            attention = bert_padding.unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
+            attention = unpad_input_only(attention, torch.squeeze(attn_mask) == 1)
 
         return attention.view(bs, dim)
 
@@ -240,8 +240,8 @@ class BertAlibiUnpadAttention(nn.Module):
         self_output = self.self(input_tensor, cu_seqlens, max_s, indices, attn_mask, bias, slopes)
         if subset_idx is not None:
             return self.output(
-                bert_padding.index_first_axis(self_output, subset_idx),
-                bert_padding.index_first_axis(input_tensor, subset_idx),
+                index_first_axis(self_output, subset_idx),
+                index_first_axis(input_tensor, subset_idx),
             )
         else:
             return self.output(self_output, input_tensor)
@@ -415,7 +415,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
             )
             attn = attn.view(bs, dim)
         else:
-            qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
+            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
             unpad_bs, seqlen, _ = qkv.shape
 
             qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)
@@ -430,7 +430,7 @@ class FlexBertUnpadAttention(FlexBertAttentionBase):
                 else None,
             )
             attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)  # b s h d
-            attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
+            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
 
         return self.out_drop(self.Wo(attn))
 
@@ -565,7 +565,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
             )
             attn = attn.view(bs, dim)
         else:
-            qkv = bert_padding.pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
+            qkv = pad_input(qkv, indices, cu_seqlens.shape[0] - 1, max_seqlen)  # batch, max_seqlen, thd
             unpad_bs, seqlen, _ = qkv.shape
 
             qkv = qkv.view(unpad_bs, -1, 3, self.num_attention_heads, self.attn_head_size)
@@ -580,7 +580,7 @@ class FlexBertUnpadParallelAttention(FlexBertAttentionBase):
                 else None,
            )
             attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)  # b s h d
-            attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
+            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
 
         return self.out_drop(self.Wo(attn.view(bs, dim)))
 
@@ -913,7 +913,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
             )
             attn = attn.view(bs, dim)
         else:
-            qkv = bert_padding.pad_input(
+            qkv = pad_input(
                 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
             )  # batch, max_seqlen, thd
             unpad_bs, seqlen, *_ = qkv.shape
@@ -929,7 +929,7 @@ class FlexBertUnpadRopeAttention(FlexBertAttentionBase):
                 else None,
             )
             attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)  # b s h d
-            attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
+            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
 
         return self.out_drop(self.Wo(attn))
 
@@ -1244,7 +1244,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
             )
             attn = attn.view(bs, dim)
         else:
-            qkv = bert_padding.pad_input(
+            qkv = pad_input(
                 qkv, indices, cu_seqlens.shape[0] - 1, attn_mask.shape[-1]
             )  # batch, max_seqlen, thd
             unpad_bs, seqlen, *_ = qkv.shape
@@ -1260,7 +1260,7 @@ class FlexBertUnpadRopeParallelAttention(FlexBertAttentionBase):
                 else None,
            )
             attn = attn.transpose(1, 2).view(unpad_bs, -1, dim)  # b s h d
-            attn = bert_padding.unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
+            attn = unpad_input_only(attn, torch.squeeze(attn_mask) == 1)
 
         return self.out_drop(self.Wo(attn))
 
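In the PyTorch fallback branches touched above, the pattern is: scatter the unpadded (total_tokens, ...) qkv tensor back into a dense (batch, max_seqlen, ...) layout with pad_input, compute attention with a padded mask, then drop the padding rows again with unpad_input_only. The sketch below is a minimal, illustrative stand-in for those two helpers; the names and signatures follow the calls in the diff, but the bodies are simplified and are not the actual .bert_padding implementations.

import torch


def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    # Scatter unpadded rows (total_tokens, ...) into a zero-padded (batch, seqlen, ...) tensor.
    out = torch.zeros(
        batch * seqlen, *hidden_states.shape[1:], dtype=hidden_states.dtype, device=hidden_states.device
    )
    out[indices] = hidden_states
    return out.view(batch, seqlen, *hidden_states.shape[1:])


def unpad_input_only(hidden_states: torch.Tensor, attention_mask_bool: torch.Tensor) -> torch.Tensor:
    # Keep only the rows of a (batch, seqlen, ...) tensor where the boolean mask is True.
    return hidden_states[attention_mask_bool]


# Round trip: two sequences of lengths 3 and 1, hidden size 4.
mask = torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]], dtype=torch.bool)
indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
unpadded = torch.randn(indices.numel(), 4)                # (total_tokens, hidden)
padded = pad_input(unpadded, indices, batch=2, seqlen=4)  # (2, 4, 4)
assert torch.equal(unpad_input_only(padded, mask), unpadded)
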
layers.py CHANGED
@@ -20,7 +20,7 @@ from typing import Optional, Union, List
 import torch
 import torch.nn as nn
 
-import bert_padding
+from .bert_padding import unpad_input, pad_input
 
 from .activation import get_act_fn
 from .attention import FlexBertAttentionBase, BertAlibiUnpadAttention, get_attention_layer
@@ -155,7 +155,7 @@ class BertAlibiEncoder(nn.Module):
         # and ntokens_unpad is total number of non-padded tokens.
         # Then unpadding performs the following compression of the inputs:
         # hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
-        hidden_states, indices, cu_seqlens, _ = bert_padding.unpad_input(hidden_states, attention_mask_bool)
+        hidden_states, indices, cu_seqlens, _ = unpad_input(hidden_states, attention_mask_bool)
 
         # Add alibi matrix to extended_attention_mask
         if self._current_alibi_size < seqlen:
@@ -190,7 +190,7 @@ class BertAlibiEncoder(nn.Module):
             # and ntokens_unpad is total number of non-padded tokens.
             # Then padding performs the following de-compression:
             # hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
-            hidden_states = bert_padding.pad_input(hidden_states, indices, batch, seqlen)
+            hidden_states = pad_input(hidden_states, indices, batch, seqlen)
         else:
             for i in range(len(self.layer) - 1):
                 layer_module = self.layer[i]
@@ -636,7 +636,7 @@ class FlexBertUnpadEncoder(FlexBertEncoderBase):
         if indices is None and cu_seqlens is None and max_seqlen is None:
             attention_mask_bool = attention_mask.bool()
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(
+            hidden_states, indices, cu_seqlens, max_seqlen = unpad_input(
                 hidden_states, attention_mask_bool
             )
 
@@ -649,7 +649,7 @@ class FlexBertUnpadEncoder(FlexBertEncoderBase):
                     attn_mask=attention_mask,
                 )
 
-            return bert_padding.pad_input(hidden_states, indices, batch, seqlen)
+            return pad_input(hidden_states, indices, batch, seqlen)
         else:
            for layer_module in self.layers:
                hidden_states = layer_module(
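
The layers.py hunks show the other half of the pattern: unpad_input compresses (batch, seqlen, hidden) activations to (total_tokens, hidden) before the layer stack and returns the bookkeeping tensors (indices, cu_seqlens, max_seqlen) that the attention modules consume, while pad_input restores the padded layout at the end. Below is a simplified stand-in, assuming only the 4-tuple return value unpacked in the diff (the real helper is more involved); it also shows why cu_seqlens.shape[0] - 1 serves as the batch size in the pad_input calls above.

import torch
import torch.nn.functional as F


def unpad_input(hidden_states: torch.Tensor, attention_mask_bool: torch.Tensor):
    # Compress (batch, seqlen, hidden) -> (total_tokens, hidden) by dropping padded positions.
    seqlens = attention_mask_bool.sum(dim=-1, dtype=torch.int32)  # non-padded tokens per sequence
    indices = torch.nonzero(attention_mask_bool.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # [0, len_0, len_0 + len_1, ...]
    max_seqlen = int(seqlens.max())
    return hidden_states.flatten(0, 1)[indices], indices, cu_seqlens, max_seqlen


# Two sequences of lengths 3 and 2, padded to seqlen 4, hidden size 8.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
hidden = torch.randn(2, 4, 8)
unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden, mask)
print(unpadded.shape)                       # torch.Size([5, 8]) -- ntokens_unpad rows
print(cu_seqlens.tolist())                  # [0, 3, 5] -- sequence boundaries in the flattened tensor
print(cu_seqlens.shape[0] - 1, max_seqlen)  # 2 3 -- batch size and longest sequence, as passed to pad_input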