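Fix skip connection in the parallel-layers branch of DecoderOnlyT5Block: add ff_output to x, scale the sum by 2**-0.5, then add its dropout to hidden_states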
Browse files
decoder_only_t5/modeling.py
CHANGED
@@ -532,9 +532,9 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
|
|
532 |
|
533 |
if self.parallel_layers:
|
534 |
# https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
|
535 |
-
|
536 |
-
|
537 |
-
hidden_states = hidden_states + self.layer[0].dropout(
|
538 |
else:
|
539 |
hidden_states = ff_layer(x)
|
540 |
|
|
|
532 |
|
533 |
if self.parallel_layers:
|
534 |
# https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
|
535 |
+
x = x + ff_output
|
536 |
+
x *= 2**-0.5
|
537 |
+
hidden_states = hidden_states + self.layer[0].dropout(x)
|
538 |
else:
|
539 |
hidden_states = ff_layer(x)
|
540 |
|