fix: try to skip initialization of task type embeddings
Browse files- modeling_bert.py +1 -1
modeling_bert.py
CHANGED
@@ -145,7 +145,7 @@ def _init_weights(module, initializer_range=0.02):
|
|
145 |
nn.init.normal_(module.weight, std=initializer_range)
|
146 |
if module.bias is not None:
|
147 |
nn.init.zeros_(module.bias)
|
148 |
-
elif isinstance(module, nn.Embedding) and not module
|
149 |
nn.init.normal_(module.weight, std=initializer_range)
|
150 |
if module.padding_idx is not None:
|
151 |
nn.init.zeros_(module.weight[module.padding_idx])
|
|
|
145 |
nn.init.normal_(module.weight, std=initializer_range)
|
146 |
if module.bias is not None:
|
147 |
nn.init.zeros_(module.bias)
|
148 |
+
elif isinstance(module, nn.Embedding) and not getattr(module, "skip_init", False):
|
149 |
nn.init.normal_(module.weight, std=initializer_range)
|
150 |
if module.padding_idx is not None:
|
151 |
nn.init.zeros_(module.weight[module.padding_idx])
|