Fix mistake in parallel layers
Browse files
- decoder_only_t5/modeling.py +42 -53
decoder_only_t5/modeling.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
from torch import nn
|
6 |
from torch.nn import CrossEntropyLoss
|
7 |
from transformers.models.t5 import modeling_t5
|
8 |
-
from transformers.modeling_outputs import
|
9 |
from transformers.utils import (
|
10 |
add_start_docstrings_to_model_forward,
|
11 |
logging,
|
@@ -167,22 +167,28 @@ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
|
|
167 |
) # (batch_size, n_heads, seq_length, dim_per_head)
|
168 |
|
169 |
# get key/value states
|
170 |
-
key_states =
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
175 |
)
|
176 |
-
value_states =
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
181 |
)
|
182 |
|
183 |
# compute scores
|
184 |
scores = torch.matmul(
|
185 |
-
query_states,
|
186 |
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
|
187 |
|
188 |
if position_bias is None:
|
@@ -345,8 +351,9 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
|
|
345 |
|
346 |
ff_layer = self.layer[-1]
|
347 |
if self.parallel_layers:
|
|
|
348 |
x = self.layer[0].layer_norm(hidden_states)
|
349 |
-
ff_output = ff_layer(
|
350 |
else:
|
351 |
x = hidden_states
|
352 |
|
@@ -418,7 +425,7 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
|
|
418 |
attention_outputs = attention_outputs + cross_attention_outputs[2:]
|
419 |
|
420 |
if self.parallel_layers:
|
421 |
-
# https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#
|
422 |
hidden_states = x + ff_output
|
423 |
hidden_states *= 2**-0.5
|
424 |
hidden_states = hidden_states + self.layer[0].dropout(hidden_states)
|
@@ -508,27 +515,21 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
|
|
508 |
|
509 |
@add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
|
510 |
@replace_return_docstrings(
|
511 |
-
output_type=
|
512 |
)
|
513 |
def forward(
|
514 |
self,
|
515 |
-
|
516 |
attention_mask: Optional[torch.FloatTensor] = None,
|
517 |
-
|
518 |
-
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
519 |
-
head_mask: Optional[torch.FloatTensor] = None,
|
520 |
-
decoder_head_mask: Optional[torch.FloatTensor] = None,
|
521 |
-
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
522 |
-
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
523 |
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
524 |
-
|
525 |
-
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
526 |
labels: Optional[torch.LongTensor] = None,
|
527 |
use_cache: Optional[bool] = None,
|
528 |
output_attentions: Optional[bool] = None,
|
529 |
output_hidden_states: Optional[bool] = None,
|
530 |
return_dict: Optional[bool] = None,
|
531 |
-
) -> Union[Tuple
|
532 |
r"""
|
533 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
534 |
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
|
@@ -548,43 +549,31 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
|
|
548 |
if self.model_parallel:
|
549 |
torch.cuda.set_device(self.decoder.first_device)
|
550 |
|
551 |
-
if (
|
552 |
-
labels is not None
|
553 |
-
and decoder_input_ids is None
|
554 |
-
and decoder_inputs_embeds is None
|
555 |
-
):
|
556 |
-
# get decoder inputs from shifting lm labels to the right
|
557 |
-
decoder_input_ids = self._shift_right(labels)
|
558 |
-
|
559 |
# Set device for model parallelism
|
560 |
if self.model_parallel:
|
561 |
torch.cuda.set_device(self.decoder.first_device)
|
562 |
-
if
|
563 |
-
|
564 |
if attention_mask is not None:
|
565 |
attention_mask = attention_mask.to(self.decoder.first_device)
|
566 |
-
if decoder_attention_mask is not None:
|
567 |
-
decoder_attention_mask = decoder_attention_mask.to(
|
568 |
-
self.decoder.first_device
|
569 |
-
)
|
570 |
|
571 |
# Decode
|
572 |
-
|
573 |
-
input_ids=
|
574 |
-
attention_mask=
|
575 |
-
inputs_embeds=
|
576 |
past_key_values=past_key_values,
|
577 |
-
|
578 |
-
encoder_attention_mask=
|
579 |
-
head_mask=
|
580 |
-
cross_attn_head_mask=
|
581 |
use_cache=use_cache,
|
582 |
output_attentions=output_attentions,
|
583 |
output_hidden_states=output_hidden_states,
|
584 |
return_dict=return_dict,
|
585 |
)
|
586 |
|
587 |
-
sequence_output =
|
588 |
|
589 |
# Set device for model parallelism
|
590 |
if self.model_parallel:
|
@@ -608,13 +597,13 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
|
|
608 |
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
609 |
|
610 |
if not return_dict:
|
611 |
-
output = (lm_logits,) +
|
612 |
return ((loss,) + output) if loss is not None else output
|
613 |
|
614 |
-
return
|
615 |
loss=loss,
|
616 |
logits=lm_logits,
|
617 |
-
past_key_values=
|
618 |
-
|
619 |
-
|
620 |
)
|
|
|
5 |
from torch import nn
|
6 |
from torch.nn import CrossEntropyLoss
|
7 |
from transformers.models.t5 import modeling_t5
|
8 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
9 |
from transformers.utils import (
|
10 |
add_start_docstrings_to_model_forward,
|
11 |
logging,
|
|
|
167 |
) # (batch_size, n_heads, seq_length, dim_per_head)
|
168 |
|
169 |
# get key/value states
|
170 |
+
key_states = repeat_kv(
|
171 |
+
project(
|
172 |
+
hidden_states,
|
173 |
+
self.k,
|
174 |
+
key_value_states,
|
175 |
+
past_key_value[0] if past_key_value is not None else None,
|
176 |
+
),
|
177 |
+
self.n_kv_groups,
|
178 |
)
|
179 |
+
value_states = repeat_kv(
|
180 |
+
project(
|
181 |
+
hidden_states,
|
182 |
+
self.v,
|
183 |
+
key_value_states,
|
184 |
+
past_key_value[1] if past_key_value is not None else None,
|
185 |
+
),
|
186 |
+
self.n_kv_groups,
|
187 |
)
|
188 |
|
189 |
# compute scores
|
190 |
scores = torch.matmul(
|
191 |
+
query_states, key_states.transpose(3, 2)
|
192 |
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
|
193 |
|
194 |
if position_bias is None:
|
|
|
351 |
|
352 |
ff_layer = self.layer[-1]
|
353 |
if self.parallel_layers:
|
354 |
+
# https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L563-L568
|
355 |
x = self.layer[0].layer_norm(hidden_states)
|
356 |
+
ff_output = ff_layer(x)
|
357 |
else:
|
358 |
x = hidden_states
|
359 |
|
|
|
425 |
attention_outputs = attention_outputs + cross_attention_outputs[2:]
|
426 |
|
427 |
if self.parallel_layers:
|
428 |
+
# https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
|
429 |
hidden_states = x + ff_output
|
430 |
hidden_states *= 2**-0.5
|
431 |
hidden_states = hidden_states + self.layer[0].dropout(hidden_states)
|
|
|
515 |
|
516 |
@add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
|
517 |
@replace_return_docstrings(
|
518 |
+
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
519 |
)
|
520 |
def forward(
|
521 |
self,
|
522 |
+
input_ids: Optional[torch.LongTensor] = None,
|
523 |
attention_mask: Optional[torch.FloatTensor] = None,
|
524 |
+
position_ids: Optional[torch.LongTensor] = None,
|
|
|
|
|
|
|
|
|
|
|
525 |
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
526 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
|
527 |
labels: Optional[torch.LongTensor] = None,
|
528 |
use_cache: Optional[bool] = None,
|
529 |
output_attentions: Optional[bool] = None,
|
530 |
output_hidden_states: Optional[bool] = None,
|
531 |
return_dict: Optional[bool] = None,
|
532 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
533 |
r"""
|
534 |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
535 |
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
|
|
|
549 |
if self.model_parallel:
|
550 |
torch.cuda.set_device(self.decoder.first_device)
|
551 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
552 |
# Set device for model parallelism
|
553 |
if self.model_parallel:
|
554 |
torch.cuda.set_device(self.decoder.first_device)
|
555 |
+
if input_ids is not None:
|
556 |
+
input_ids = input_ids.to(self.decoder.first_device)
|
557 |
if attention_mask is not None:
|
558 |
attention_mask = attention_mask.to(self.decoder.first_device)
|
|
|
|
|
|
|
|
|
559 |
|
560 |
# Decode
|
561 |
+
outputs = self.decoder(
|
562 |
+
input_ids=input_ids,
|
563 |
+
attention_mask=attention_mask,
|
564 |
+
inputs_embeds=inputs_embeds,
|
565 |
past_key_values=past_key_values,
|
566 |
+
encoder_hidden_states=None,
|
567 |
+
encoder_attention_mask=None,
|
568 |
+
head_mask=None,
|
569 |
+
cross_attn_head_mask=None,
|
570 |
use_cache=use_cache,
|
571 |
output_attentions=output_attentions,
|
572 |
output_hidden_states=output_hidden_states,
|
573 |
return_dict=return_dict,
|
574 |
)
|
575 |
|
576 |
+
sequence_output = outputs[0]
|
577 |
|
578 |
# Set device for model parallelism
|
579 |
if self.model_parallel:
|
|
|
597 |
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
598 |
|
599 |
if not return_dict:
|
600 |
+
output = (lm_logits,) + outputs[1:]
|
601 |
return ((loss,) + output) if loss is not None else output
|
602 |
|
603 |
+
return CausalLMOutputWithPast(
|
604 |
loss=loss,
|
605 |
logits=lm_logits,
|
606 |
+
past_key_values=outputs.past_key_values,
|
607 |
+
hidden_states=outputs.hidden_states,
|
608 |
+
attentions=outputs.attentions,
|
609 |
)
|