oweller2
committed on
Commit
•
8686e3f
1
Parent(s):
cf03b9b
added
Browse files- modeling_flexbert.py +9 -3
modeling_flexbert.py
CHANGED
@@ -1536,7 +1536,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
|
|
1536 |
# Initialize weights and apply final processing
|
1537 |
self._init_weights(reset_params=False)
|
1538 |
|
1539 |
-
def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
1540 |
# Handle the XOR condition
|
1541 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|
1542 |
|
@@ -1556,7 +1556,7 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
|
|
1556 |
|
1557 |
if not self.config.tie_word_embeddings:
|
1558 |
init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
|
1559 |
-
|
1560 |
@classmethod
|
1561 |
def from_composer(
|
1562 |
cls,
|
@@ -1702,13 +1702,19 @@ class FlexBertForCasualLM(FlexBertPreTrainedModel):
|
|
1702 |
)
|
1703 |
|
1704 |
if self.pad_logits:
|
|
|
|
|
|
|
|
|
1705 |
return CausalLMOutput(
|
1706 |
loss=loss,
|
1707 |
-
logits=
|
1708 |
hidden_states=None,
|
1709 |
attentions=None,
|
1710 |
)
|
1711 |
else:
|
|
|
|
|
1712 |
return CausalLMOutput(
|
1713 |
loss=loss,
|
1714 |
logits=logits,
|
|
|
1536 |
# Initialize weights and apply final processing
|
1537 |
self._init_weights(reset_params=False)
|
1538 |
|
1539 |
+
[] def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
|
1540 |
# Handle the XOR condition
|
1541 |
assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
|
1542 |
|
|
|
1556 |
|
1557 |
if not self.config.tie_word_embeddings:
|
1558 |
init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
|
1559 |
+
|
1560 |
@classmethod
|
1561 |
def from_composer(
|
1562 |
cls,
|
|
|
1702 |
)
|
1703 |
|
1704 |
if self.pad_logits:
|
1705 |
+
# Reshape logits to 3D if needed
|
1706 |
+
new_logits = self.pad_inputs(logits, indices, batch_size, seq_len)[0]
|
1707 |
+
if len(new_logits.shape) == 2:
|
1708 |
+
new_logits = new_logits.unsqueeze(0)
|
1709 |
return CausalLMOutput(
|
1710 |
loss=loss,
|
1711 |
+
logits=new_logits,
|
1712 |
hidden_states=None,
|
1713 |
attentions=None,
|
1714 |
)
|
1715 |
else:
|
1716 |
+
if len(logits.shape) == 2:
|
1717 |
+
logits = logits.unsqueeze(0)
|
1718 |
return CausalLMOutput(
|
1719 |
loss=loss,
|
1720 |
logits=logits,
|