feat: added option for QK normalization
Browse files- configuration_bert.py +2 -0
- modeling_bert.py +5 -3
configuration_bert.py
CHANGED
@@ -83,6 +83,7 @@ class JinaBertConfig(PretrainedConfig):
|
|
83 |
pad_vocab_size_multiple=1,
|
84 |
num_tasks=0,
|
85 |
use_flash_attn=True,
|
|
|
86 |
**kwargs,
|
87 |
):
|
88 |
assert 'position_embedding_type' not in kwargs
|
@@ -110,3 +111,4 @@ class JinaBertConfig(PretrainedConfig):
|
|
110 |
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
111 |
self.num_tasks = num_tasks
|
112 |
self.use_flash_attn = use_flash_attn
|
|
|
|
83 |
pad_vocab_size_multiple=1,
|
84 |
num_tasks=0,
|
85 |
use_flash_attn=True,
|
86 |
+
use_qk_norm=True,
|
87 |
**kwargs,
|
88 |
):
|
89 |
assert 'position_embedding_type' not in kwargs
|
|
|
111 |
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
112 |
self.num_tasks = num_tasks
|
113 |
self.use_flash_attn = use_flash_attn
|
114 |
+
self.use_qk_norm = use_qk_norm
|
modeling_bert.py
CHANGED
@@ -59,9 +59,10 @@ logger = logging.getLogger(__name__)
|
|
59 |
|
60 |
|
61 |
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
62 |
-
use_flash_attn = … [rest of removed line lost in page extraction]
|
63 |
-
… [removed line lost in page extraction]
|
64 |
-
… [removed line lost in page extraction]
|
|
|
65 |
mixer_cls = partial(
|
66 |
MHA,
|
67 |
num_heads=config.num_attention_heads,
|
@@ -73,6 +74,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
73 |
return_residual=return_residual,
|
74 |
use_alibi=True,
|
75 |
window_size=window_size,
|
|
|
76 |
)
|
77 |
return mixer_cls
|
78 |
|
|
|
59 |
|
60 |
|
61 |
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
62 |
+
use_flash_attn = config.use_flash_attn
|
63 |
+
use_qk_norm = config.use_qk_norm
|
64 |
+
fused_bias_fc = config.fused_bias_fc
|
65 |
+
window_size = config.window_size
|
66 |
mixer_cls = partial(
|
67 |
MHA,
|
68 |
num_heads=config.num_attention_heads,
|
|
|
74 |
return_residual=return_residual,
|
75 |
use_alibi=True,
|
76 |
window_size=window_size,
|
77 |
+
qk_norm=use_qk_norm
|
78 |
)
|
79 |
return mixer_cls
|
80 |
|