Markus28 commited on
Commit
463061d
1 Parent(s): 2e69073

feat: added option for QK normalization

Browse files
Files changed (2) hide show
  1. configuration_bert.py +2 -0
  2. modeling_bert.py +5 -3
configuration_bert.py CHANGED
@@ -83,6 +83,7 @@ class JinaBertConfig(PretrainedConfig):
83
  pad_vocab_size_multiple=1,
84
  num_tasks=0,
85
  use_flash_attn=True,
 
86
  **kwargs,
87
  ):
88
  assert 'position_embedding_type' not in kwargs
@@ -110,3 +111,4 @@ class JinaBertConfig(PretrainedConfig):
110
  self.pad_vocab_size_multiple = pad_vocab_size_multiple
111
  self.num_tasks = num_tasks
112
  self.use_flash_attn = use_flash_attn
 
 
83
  pad_vocab_size_multiple=1,
84
  num_tasks=0,
85
  use_flash_attn=True,
86
+ use_qk_norm=True,
87
  **kwargs,
88
  ):
89
  assert 'position_embedding_type' not in kwargs
 
111
  self.pad_vocab_size_multiple = pad_vocab_size_multiple
112
  self.num_tasks = num_tasks
113
  self.use_flash_attn = use_flash_attn
114
+ self.use_qk_norm = use_qk_norm
modeling_bert.py CHANGED
@@ -59,9 +59,10 @@ logger = logging.getLogger(__name__)
59
 
60
 
61
  def create_mixer_cls(config, cross_attn=False, return_residual=False):
62
- use_flash_attn = getattr(config, "use_flash_attn", False)
63
- fused_bias_fc = getattr(config, "fused_bias_fc", False)
64
- window_size = getattr(config, "window_size", (-1, -1))
 
65
  mixer_cls = partial(
66
  MHA,
67
  num_heads=config.num_attention_heads,
@@ -73,6 +74,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
73
  return_residual=return_residual,
74
  use_alibi=True,
75
  window_size=window_size,
 
76
  )
77
  return mixer_cls
78
 
 
59
 
60
 
61
  def create_mixer_cls(config, cross_attn=False, return_residual=False):
62
+ use_flash_attn = config.use_flash_attn
63
+ use_qk_norm = config.use_qk_norm
64
+ fused_bias_fc = config.fused_bias_fc
65
+ window_size = config.window_size
66
  mixer_cls = partial(
67
  MHA,
68
  num_heads=config.num_attention_heads,
 
74
  return_residual=return_residual,
75
  use_alibi=True,
76
  window_size=window_size,
77
+ qk_norm=use_qk_norm
78
  )
79
  return mixer_cls
80