Markus28 committed
Commit a2c07ba
1 Parent(s): bfc0b2d

feat: added back option to disable flash attention

Files changed (2)
  1. configuration_bert.py +3 -0
  2. modeling_bert.py +4 -2
configuration_bert.py CHANGED
@@ -57,6 +57,7 @@ class JinaBertConfig(PretrainedConfig):
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
         window_size (`tuple`, *optional*, defaults to `(-1, -1)`): If not the default, use local attention
+        use_flash_attn (`bool`, *optional*, defaults to `True`): Whether or not to use flash attention
     """
 
     model_type = "bert"
@@ -76,6 +77,7 @@ class JinaBertConfig(PretrainedConfig):
         layer_norm_eps=1e-12,
         pad_token_id=0,
         window_size=(-1, -1),
+        use_flash_attn=True,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -92,4 +94,5 @@ class JinaBertConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.window_size = window_size
+        self.use_flash_attn = use_flash_attn
 
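For reference, a minimal usage sketch (not part of the commit, and assuming configuration_bert.py is importable directly): the new flag is an ordinary constructor argument on JinaBertConfig and defaults to True, so flash attention can be switched off when the config is built.

# Usage sketch; assumes configuration_bert.py is on the Python path.
from configuration_bert import JinaBertConfig

# Default behaviour is unchanged: flash attention stays enabled.
default_config = JinaBertConfig()
assert default_config.use_flash_attn is True

# Passing use_flash_attn=False opts out of flash attention for this model.
no_flash_config = JinaBertConfig(use_flash_attn=False)
assert no_flash_config.use_flash_attn is False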
modeling_bert.py CHANGED
@@ -62,6 +62,7 @@ logger = logging.getLogger(__name__)
 
 
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
+    use_flash_attn = getattr(config, "use_flash_attn", True)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     window_size = getattr(config, "window_size", (-1, -1))
     mixer_cls = partial(
@@ -71,7 +72,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         dropout=config.attention_probs_dropout_prob,
         causal=False,
         fused_bias_fc=fused_bias_fc,
-        use_flash_attn=True,
+        use_flash_attn=use_flash_attn,
         return_residual=return_residual,
         use_alibi=True,
         window_size=window_size,
@@ -154,6 +155,7 @@ def _init_weights(module, initializer_range=0.02):
 class BertEncoder(nn.Module):
     def __init__(self, config: JinaBertConfig):
         super().__init__()
+        self.use_flash_attn = getattr(config, "use_flash_attn", True)
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -163,7 +165,7 @@ class BertEncoder(nn.Module):
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
         """
-        if key_padding_mask is None:
+        if key_padding_mask is None or not self.use_flash_attn:
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
             )
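The change to BertEncoder.forward means the unpadded (flash-attention varlen) path is only taken when flash attention is enabled and a key_padding_mask is present; with use_flash_attn=False the encoder stays on the padded path and forwards the mask through mixer_kwargs. Below is a simplified sketch of that branch, not the full method: it assumes, as in the flash_attn BERT reference implementation this file is based on, that each layer accepts a mixer_kwargs argument, and it omits the unpad/repad and subset_mask handling.

# Simplified sketch of the BertEncoder.forward branching after this commit.
def encoder_forward_sketch(self, hidden_states, key_padding_mask=None):
    if key_padding_mask is None or not self.use_flash_attn:
        # Padded path: used when there is nothing to unpad, or when flash
        # attention is disabled; the mask (if any) is passed straight to the mixer.
        mixer_kwargs = (
            {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
        )
        for layer in self.layers:
            hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        return hidden_states
    # Flash-attention path: tokens are unpadded into a packed sequence for the
    # varlen kernel and re-padded afterwards (details omitted in this sketch).
    ...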