oweller2
commited on
Commit
·
4b203f9
1
Parent(s):
e0229bb
updates
Browse files- config.json +2 -2
- modeling_flexbert.py +13 -3
config.json
CHANGED
@@ -2,12 +2,12 @@
|
|
2 |
"allow_embedding_resizing": true,
|
3 |
"architectures": [
|
4 |
"FlexBertModel",
|
5 |
-
"
|
6 |
],
|
7 |
"auto_map": {
|
8 |
"AutoConfig": "orionweller/test-flex-gpt--configuration_bert.FlexBertConfig",
|
9 |
"AutoModel": "orionweller/test-flex-gpt--modeling_flexbert.FlexBertModel",
|
10 |
-
"AutoModelForCausalLM": "orionweller/test-flex-gpt--modeling_flexbert.
|
11 |
},
|
12 |
"attention_layer": "rope",
|
13 |
"attention_probs_dropout_prob": 0.0,
|
|
|
2 |
"allow_embedding_resizing": true,
|
3 |
"architectures": [
|
4 |
"FlexBertModel",
|
5 |
+
"FlexBertForCausalLM"
|
6 |
],
|
7 |
"auto_map": {
|
8 |
"AutoConfig": "orionweller/test-flex-gpt--configuration_bert.FlexBertConfig",
|
9 |
"AutoModel": "orionweller/test-flex-gpt--modeling_flexbert.FlexBertModel",
|
10 |
+
"AutoModelForCausalLM": "orionweller/test-flex-gpt--modeling_flexbert.FlexBertForCausalLM"
|
11 |
},
|
12 |
"attention_layer": "rope",
|
13 |
"attention_probs_dropout_prob": 0.0,
|
modeling_flexbert.py
CHANGED
@@ -1534,14 +1534,23 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1534 |
self._init_weights(reset_params=False)
|
1535 |
|
1536 |
def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
|
|
1537 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|
1538 |
-
|
1539 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1540 |
else:
|
1541 |
assert isinstance(reset_params, bool)
|
1542 |
self.bert._init_weights(reset_params=reset_params)
|
1543 |
self.lm_head._init_weights(reset_params=reset_params)
|
1544 |
-
|
1545 |
if not self.config.tie_word_embeddings:
|
1546 |
init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
|
1547 |
|
@@ -1742,6 +1751,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
|
|
1742 |
params += _count_parameters(self.lm_head, trainable)
|
1743 |
return params
|
1744 |
|
|
|
1745 |
|
1746 |
def init_model_from_pretrained(
|
1747 |
pretrained_model: FlexBertModel,
|
|
|
1534 |
self._init_weights(reset_params=False)
|
1535 |
|
1536 |
def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
1537 |
+
# Handle the XOR condition
|
1538 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|
1539 |
+
|
1540 |
+
if module is not None:
|
1541 |
+
# Add basic initialization for common module types
|
1542 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
1543 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
1544 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
1545 |
+
module.bias.data.zero_()
|
1546 |
+
elif isinstance(module, nn.LayerNorm):
|
1547 |
+
module.bias.data.zero_()
|
1548 |
+
module.weight.data.fill_(1.0)
|
1549 |
else:
|
1550 |
assert isinstance(reset_params, bool)
|
1551 |
self.bert._init_weights(reset_params=reset_params)
|
1552 |
self.lm_head._init_weights(reset_params=reset_params)
|
1553 |
+
|
1554 |
if not self.config.tie_word_embeddings:
|
1555 |
init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
|
1556 |
|
|
|
1751 |
params += _count_parameters(self.lm_head, trainable)
|
1752 |
return params
|
1753 |
|
1754 |
+
FlexBertForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
1755 |
|
1756 |
def init_model_from_pretrained(
|
1757 |
pretrained_model: FlexBertModel,
|