Markus28 committed on
Commit
bfc0b2d
·
1 Parent(s): 953f39e

fix: always use flash attention

Browse files
Files changed (1) hide show
  1. modeling_bert.py +1 -3
modeling_bert.py CHANGED
@@ -154,7 +154,6 @@ def _init_weights(module, initializer_range=0.02):
154
  class BertEncoder(nn.Module):
155
  def __init__(self, config: JinaBertConfig):
156
  super().__init__()
157
- self.use_flash_attn = getattr(config, "use_flash_attn", False)
158
  self.layers = nn.ModuleList(
159
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
160
  )
@@ -164,13 +163,12 @@ class BertEncoder(nn.Module):
164
  This means that we only compute the last layer output for these tokens.
165
  subset_mask: (batch, seqlen), dtype=torch.bool
166
  """
167
- if key_padding_mask is None or not self.use_flash_attn:
168
  mixer_kwargs = (
169
  {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
170
  )
171
  for layer in self.layers:
172
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
173
- print(hidden_states)
174
  if subset_mask is not None:
175
  hidden_states = hidden_states[subset_mask]
176
  else:
 
154
  class BertEncoder(nn.Module):
155
  def __init__(self, config: JinaBertConfig):
156
  super().__init__()
 
157
  self.layers = nn.ModuleList(
158
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
159
  )
 
163
  This means that we only compute the last layer output for these tokens.
164
  subset_mask: (batch, seqlen), dtype=torch.bool
165
  """
166
+ if key_padding_mask is None:
167
  mixer_kwargs = (
168
  {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
169
  )
170
  for layer in self.layers:
171
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
 
172
  if subset_mask is not None:
173
  hidden_states = hidden_states[subset_mask]
174
  else: