Fix batch beam search
modeling_glm.py  +90 -34
@@ -29,13 +29,13 @@ from transformers.utils import (
 )
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
+    ModelOutput,
     SequenceClassifierOutput,
-    ModelOutput
 )
+
 from transformers.modeling_utils import (
     PreTrainedModel,
 )
-from transformers.utils import logging
 from .configuration_glm import GLMConfig
 from torch.nn.parameter import Parameter

@@ -781,20 +781,60 @@ class GLMModel(GLMPreTrainedModel):
         attention_mask = torch.zeros(batch_size)
         # Transformer.
         transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)
-
-
+        last_hidden_states, mems = transformer_output
+        logits = None
         if self.output_predict:
-
-            # logits_parallel = mpu.copy_to_model_parallel_region(
-            #     logits)
-            logits = F.linear(logits, self.word_embeddings.weight)
+            logits = F.linear(last_hidden_states, self.word_embeddings.weight)

         return ModelOutput(
+            last_hidden_states=last_hidden_states,
             logits=logits,
-            mems=
+            mems=mems,
         )


+@add_start_docstrings(
+    """GLM Model transformer for multiple choice classification""",
+    GLM_START_DOCSTRING
+)
+class GLMForMultipleChoice(GLMPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.glm = GLMModel(config)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids=None,
+        position_ids=None,
+        attention_mask=None,
+        choice_ids=None,
+        choice_indices=None,
+        labels=None,
+        mems=None,
+        **kwargs
+    ):
+        model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
+        lm_logits = model_output.logits
+        log_probs = []
+        for output, choices, choice_index in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):
+            log_probs_single = []
+            for choice, choice_target_id in zip(choices, choice_index):
+                tmp = output[choice_target_id, choice]
+                log_probs_single.append(tmp.sum())
+            log_probs.append(torch.stack(log_probs_single))
+        log_probs = torch.stack(log_probs)
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(log_probs, labels)
+        return ModelOutput(
+            loss=loss,
+            logits=log_probs,
+            lm_logits=lm_logits,
+            mems=model_output.mems
+        )
+
 @add_start_docstrings(
     """GLM Model transformer with a `language modeling` head on top""",
     GLM_START_DOCSTRING,
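The new GLMForMultipleChoice head in the hunk above scores each candidate answer by summing the log-probabilities of its tokens at the positions where they are predicted. A minimal standalone sketch of that scoring loop follows; all shapes and token ids are placeholders, and the real choice_ids / choice_indices would come from the GLM tokenizer's multiple-choice preprocessing, which is assumed here and not part of this commit.

import torch
import torch.nn.functional as F

# Standalone sketch of the scoring loop in GLMForMultipleChoice.forward.
# All shapes and ids below are illustrative placeholders, not real GLM vocabulary ids.
vocab_size, seq_length = 50, 8
lm_logits = torch.randn(1, seq_length, vocab_size)            # one sample in the batch
choice_ids = [[torch.tensor([11, 12]), torch.tensor([13])]]    # two candidate answers (token ids)
choice_indices = [[torch.tensor([3, 4]), torch.tensor([3])]]   # positions that predict those tokens

log_probs = []
for output, choices, positions in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):
    per_choice = []
    for choice, pos in zip(choices, positions):
        # log-probability of each token of this choice at its target position, summed
        per_choice.append(output[pos, choice].sum())
    log_probs.append(torch.stack(per_choice))
log_probs = torch.stack(log_probs)  # (batch, num_choices); the argmax is the predicted answer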
@@ -833,6 +873,16 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
             position_ids = position_ids[:, :, :seq_length]
         if attention_mask is not None:
             attention_mask = attention_mask[:, :, :seq_length, :seq_length]
+        if position_ids is not None and input_ids.size(0) > position_ids.size(0):
+            batch_size = position_ids.size(0)
+            num_beams = input_ids.size(0) // batch_size
+            position_ids = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)
+            position_ids = position_ids.reshape(batch_size * num_beams, *position_ids.shape[-2:])
+        if attention_mask is not None and input_ids.size(0) > attention_mask.size(0):
+            batch_size = attention_mask.size(0)
+            num_beams = input_ids.size(0) // batch_size
+            attention_mask = attention_mask.unsqueeze(1).expand(-1, num_beams, -1, -1, -1)
+            attention_mask = attention_mask.reshape(batch_size * num_beams, *attention_mask.shape[-3:])
         return {
             "input_ids": input_ids,
             "position_ids": position_ids,
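The hunk above is the actual batch beam-search fix: generate() expands input_ids to batch_size * num_beams rows, but GLM's 2D position ids and 4D attention mask were still sized for batch_size, which broke beam search whenever more than one prompt was in the batch. A minimal sketch of the expansion for position_ids, with illustrative shapes (the attention mask follows the same pattern with one extra dimension):

import torch

# Illustrative shapes only: 2 samples, 4 beams, GLM-style 2D position ids (batch, 2, seq).
batch_size, num_beams, seq_length = 2, 4, 10
position_ids = torch.arange(seq_length).repeat(batch_size, 2, 1)                # (2, 2, 10)
input_ids = torch.zeros(batch_size * num_beams, seq_length, dtype=torch.long)   # what generate() passes back

if input_ids.size(0) > position_ids.size(0):
    num_beams = input_ids.size(0) // position_ids.size(0)
    # tile each sample's position ids across its beams, then flatten back into the batch dimension
    position_ids = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)
    position_ids = position_ids.reshape(batch_size * num_beams, *position_ids.shape[-2:])

print(position_ids.shape)  # torch.Size([8, 2, 10])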
@@ -845,10 +895,21 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
         input_ids=None,
         position_ids=None,
         attention_mask=None,
+        labels=None,
         mems=None,
         **kwargs
     ):
-
+        model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
+        lm_logits = model_output.logits
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss(ignore_index=-100)
+            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+        return ModelOutput(
+            loss=loss,
+            logits=lm_logits,
+            mems=model_output.mems
+        )


 @add_start_docstrings(
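GLMForConditionalGeneration.forward now accepts labels and computes a token-level cross-entropy with ignore_index=-100, so positions that should not contribute to the loss are expected to carry -100. A small sketch of that convention, assuming a hypothetical prompt length and pad id:

import torch

# Hypothetical example: mask out prompt and padding positions with -100 so that
# CrossEntropyLoss(ignore_index=-100) only trains on the generated span.
input_ids = torch.tensor([[101, 7, 8, 9, 102, 0, 0]])  # placeholder token ids
prompt_length, pad_id = 2, 0

labels = input_ids.clone()
labels[:, :prompt_length] = -100   # ignore the prompt tokens
labels[labels == pad_id] = -100    # ignore the padding tokens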
@@ -857,16 +918,19 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
     GLM_START_DOCSTRING,
 )
 class GLMForSequenceClassification(GLMPreTrainedModel):
-    def __init__(self, config, hidden_dropout=
+    def __init__(self, config: GLMConfig, hidden_dropout=None, num_class=1):
         super().__init__(config)
         self.pool_token = config.pool_token
         self.glm = GLMModel(config)
         self.glm.output_predict = False
         self.num_class = num_class
         # Multi-choice head.
-        self.
-
-
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.output_dropout_prob
+        )
+        self.dropout = torch.nn.Dropout(classifier_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
         self.post_init()
@@ -891,29 +955,21 @@ class GLMForSequenceClassification(GLMPreTrainedModel):
             input_ids = input_ids.reshape(-1, input_ids.size(-1))
             attention_mask = attention_mask.reshape(-1, *attention_mask.size()[2:])
             position_ids = position_ids.reshape(-1, *position_ids.size()[2:])
-        model_out = self.glm
-        outputs, mems = model_out.
-
-        if self.pool_token == 'start':
-            output = outputs[
-                torch.arange(outputs.size(0), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask]
-        elif self.pool_token == 'pad':
-            output = outputs[torch.arange(outputs.size(0), dtype=attention_mask.dtype,
-                                          device=attention_mask.device), attention_mask - 1]
-        elif self.pool_token == 'cls':
-            output = outputs[:, 0]
-        else:
-            raise NotImplementedError
+        model_out = self.glm(input_ids, position_ids, attention_mask)
+        outputs, mems = model_out.last_hidden_states, model_out.mems

-        output =
-
-
-
+        output = outputs[:, 0, :]
+        output = self.dropout(output)
+        output = torch.tanh(self.dense(output))
+        output = self.dropout(output)
+        logits = self.out_proj(output)
         if num_choices is not None:
             logits = logits.view(-1, num_choices)
-
-
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits, labels)
         # loss = F.cross_entropy(logits.contiguous().float(), labels.long())
         return SequenceClassifierOutput(loss=loss,
                                         logits=logits,
-                                        hidden_states=
+                                        hidden_states=outputs)
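Taken together, the changes make batched beam search work end to end. A hedged usage sketch follows, based on the public THUDM GLM model cards: the checkpoint name, build_inputs_for_generation, and eop_token_id are taken from those cards and are assumptions here, not something introduced by this commit.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Assumed checkpoint and tokenizer helpers (per the THUDM GLM model cards); the
# point is only that num_beams > 1 now works with a batch of more than one prompt.
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-10b", trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained("THUDM/glm-10b", trust_remote_code=True)

prompts = [
    "Tsinghua University is located in [MASK].",
    "GLM is pretrained with an autoregressive [MASK] objective.",
]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=256)

outputs = model.generate(
    **inputs,
    max_length=256,
    num_beams=4,                          # a batch of 2 expands to 8 beam rows internally
    eos_token_id=tokenizer.eop_token_id,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))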