Markus28 commited on
Commit
3b35eab
1 Parent(s): 95ca1a8

fix: try to skip initialization of task type embeddings

Browse files
Files changed (1) hide show
  1. modeling_bert.py +1 -1
modeling_bert.py CHANGED
@@ -145,7 +145,7 @@ def _init_weights(module, initializer_range=0.02):
145
  nn.init.normal_(module.weight, std=initializer_range)
146
  if module.bias is not None:
147
  nn.init.zeros_(module.bias)
148
- elif isinstance(module, nn.Embedding) and not module.skip_init:
149
  nn.init.normal_(module.weight, std=initializer_range)
150
  if module.padding_idx is not None:
151
  nn.init.zeros_(module.weight[module.padding_idx])
 
145
  nn.init.normal_(module.weight, std=initializer_range)
146
  if module.bias is not None:
147
  nn.init.zeros_(module.bias)
148
+ elif isinstance(module, nn.Embedding) and not getattr(module, "skip_init", False):
149
  nn.init.normal_(module.weight, std=initializer_range)
150
  if module.padding_idx is not None:
151
  nn.init.zeros_(module.weight[module.padding_idx])