Revert "feat: added back option to disable flash attention"
This reverts commit a2c07ba3266b5f4aff5b5f4de98324d3d69171db.
- configuration_bert.py +0 -3
- modeling_bert.py +2 -4
configuration_bert.py CHANGED
@@ -57,7 +57,6 @@ class JinaBertConfig(PretrainedConfig):
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
         window_size (`tuple`, *optional*, defaults to `(-1, -1)`): If not the default, use local attention
-        use_flash_attn (`bool`, *optional*, defaults to `True`): Whether or not to use flash attention
     """
 
     model_type = "bert"
@@ -77,7 +76,6 @@ class JinaBertConfig(PretrainedConfig):
         layer_norm_eps=1e-12,
         pad_token_id=0,
         window_size=(-1, -1),
-        use_flash_attn=True,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -94,5 +92,4 @@ class JinaBertConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.window_size = window_size
-        self.use_flash_attn = use_flash_attn
 
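For reference, the practical effect of the configuration change is that `JinaBertConfig` no longer defines a `use_flash_attn` attribute, so the `getattr(config, "use_flash_attn", True)` lookups used on the modeling side before this revert would always fall back to their default. Below is a minimal sketch of that behavior; `DummyConfig` is a hypothetical stand-in, not the real `JinaBertConfig`.

# Minimal sketch of the config-side effect of this revert.
# DummyConfig is a hypothetical stand-in for JinaBertConfig after the change:
# it no longer defines a use_flash_attn attribute at all.
class DummyConfig:
    def __init__(self, window_size=(-1, -1), layer_norm_eps=1e-12):
        self.window_size = window_size
        self.layer_norm_eps = layer_norm_eps


cfg = DummyConfig()

# The lookup pattern the modeling code used before the revert now always
# falls back to its default, so flash attention cannot be switched off here.
print(getattr(cfg, "use_flash_attn", True))  # -> True
print(hasattr(cfg, "use_flash_attn"))        # -> False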
modeling_bert.py CHANGED
@@ -62,7 +62,6 @@ logger = logging.getLogger(__name__)
 
 
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
-    use_flash_attn = getattr(config, "use_flash_attn", True)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     window_size = getattr(config, "window_size", (-1, -1))
     mixer_cls = partial(
@@ -72,7 +71,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         dropout=config.attention_probs_dropout_prob,
         causal=False,
         fused_bias_fc=fused_bias_fc,
-        use_flash_attn=use_flash_attn,
+        use_flash_attn=True,
         return_residual=return_residual,
         use_alibi=True,
         window_size=window_size,
@@ -155,7 +154,6 @@ def _init_weights(module, initializer_range=0.02):
 class BertEncoder(nn.Module):
     def __init__(self, config: JinaBertConfig):
         super().__init__()
-        self.use_flash_attn = getattr(config, "use_flash_attn", True)
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -165,7 +163,7 @@ class BertEncoder(nn.Module):
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
         """
-        if key_padding_mask is None or not self.use_flash_attn:
+        if key_padding_mask is None:
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
             )
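On the modeling side, `create_mixer_cls` goes back to hard-coding `use_flash_attn=True` when pre-binding the mixer's keyword arguments with `functools.partial`, and `BertEncoder` no longer stores or consults a `use_flash_attn` flag. The following is a hedged sketch of that `partial` pattern only; `MixerStub` is a placeholder class invented for illustration, not the real mixer used by the model.

# Sketch of the create_mixer_cls pattern: functools.partial pre-binds keyword
# arguments onto the mixer class, with use_flash_attn hard-coded to True again
# by this revert. MixerStub is a placeholder, not the actual mixer class.
from functools import partial


class MixerStub:
    def __init__(self, embed_dim, dropout=0.0, causal=False,
                 use_flash_attn=True, use_alibi=False, window_size=(-1, -1)):
        self.embed_dim = embed_dim
        self.use_flash_attn = use_flash_attn
        self.window_size = window_size


mixer_cls = partial(
    MixerStub,
    dropout=0.1,
    causal=False,
    use_flash_attn=True,   # hard-coded; the config can no longer override it
    use_alibi=True,
    window_size=(-1, -1),
)

mixer = mixer_cls(embed_dim=768)   # remaining arguments supplied at call time
print(mixer.use_flash_attn)        # -> True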