fix: always use flash attention
modeling_bert.py  +1 -3
modeling_bert.py  CHANGED
@@ -154,7 +154,6 @@ def _init_weights(module, initializer_range=0.02):
 class BertEncoder(nn.Module):
     def __init__(self, config: JinaBertConfig):
         super().__init__()
-        self.use_flash_attn = getattr(config, "use_flash_attn", False)
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -164,13 +163,12 @@ class BertEncoder(nn.Module):
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
         """
-        if key_padding_mask is None or not self.use_flash_attn:
+        if key_padding_mask is None:
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
             )
             for layer in self.layers:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-            print(hidden_states)
             if subset_mask is not None:
                 hidden_states = hidden_states[subset_mask]
         else: