Support batch training
modeling_chatglm.py (CHANGED, +26 -23)
@@ -818,33 +818,37 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return past_key_values

     @staticmethod
-    def get_masks(self, seq, device):
-        context_length = seq.index(self.config.bos_token_id) + 1
-
-        attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
+    def get_masks(self, input_ids, device):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :context_length - 1] = 1
+        for i, context_length in enumerate(context_lengths):
+            attention_mask[i, :, :context_length] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()

         return attention_mask

-    def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = len(seq)
+    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            seq_length = seq.index(self.config.bos_token_id)
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[seq_length:] = mask_position
-            block_position_ids = torch.cat((
-                torch.zeros(seq_length, dtype=torch.long, device=device),
-                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
-            ))
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[i, context_length:] = mask_positions[i]
+            block_position_ids = [torch.cat((
+                torch.zeros(context_length, dtype=torch.long, device=device),
+                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+            )) for context_length in context_lengths]
+            block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[context_length - 1:] = mask_position
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[context_length:] = mask_positions[i]

         position_ids = position_ids.unsqueeze(0)

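The two helpers above now derive a per-sequence context length (the position of bos_token_id) instead of assuming a single sequence. A minimal standalone sketch of what they compute, using a made-up bos_token_id and two toy sequences (none of these values come from the real ChatGLM vocabulary), covering the batched mask and the gMASK branch of the 2D position ids:

import torch

# Standalone sketch of the batched get_masks / get_position_ids logic above.
# Assumed values (not from the patch): bos_token_id = 5, two toy sequences.
bos_token_id = 5
input_ids = torch.tensor([
    [11, 12, 13, 5, 21, 22],   # context length 3 (tokens before bos)
    [31, 32, 33, 34, 5, 41],   # context length 4
])
batch_size, seq_length = input_ids.shape
context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]

# Batched mask: each prompt is fully visible, positions after bos are causal, as in get_masks.
attention_mask = torch.ones((batch_size, seq_length, seq_length))
attention_mask.tril_()
for i, context_length in enumerate(context_lengths):
    attention_mask[i, :, :context_length] = 1
attention_mask.unsqueeze_(1)
attention_mask = (attention_mask < 0.5).bool()      # True marks positions to be masked out
print(attention_mask.shape)                          # torch.Size([2, 1, 6, 6])

# Batched 2D position ids (the gmask=True branch of get_position_ids).
position_ids = torch.arange(seq_length, dtype=torch.long).expand(batch_size, seq_length)
block_position_ids = torch.stack([
    torch.cat((
        torch.zeros(context_length, dtype=torch.long),
        torch.arange(seq_length - context_length, dtype=torch.long) + 1,
    ))
    for context_length in context_lengths
])
position_ids = torch.stack((position_ids, block_position_ids), dim=1)
print(position_ids.shape)                            # torch.Size([2, 2, 6])

Each row gets its own context length, so the always-visible prompt region and the start of the block position ids differ per sequence; that per-row handling is what the batched rewrite adds.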
@@ -890,16 +894,15 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
                 past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
             else:
                 past_key_values = tuple([None] * len(self.layers))
-            seq = input_ids[0].tolist()

         if attention_mask is None:
             attention_mask = self.get_masks(
-                seq=seq,
+                input_ids,
                 device=input_ids.device
             )

         if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(1, 1, len(seq), self.pre_seq_len).to(attention_mask.device)
+            prefix_attention_mask = torch.ones(1, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

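For the P-tuning v2 path, the learned prefix is prepended as always-visible columns of the attention mask. A rough standalone shape check, assuming pre_seq_len = 4 and a sequence length of 6 (both made up):

import torch

# Shape check for the prefix-mask concatenation above (assumed sizes, batch of 1).
pre_seq_len, seq_length = 4, 6
attention_mask = torch.zeros(1, 1, seq_length, seq_length, dtype=torch.bool)   # stand-in for get_masks output

prefix_attention_mask = torch.ones(1, 1, seq_length, pre_seq_len)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()    # all False: prefix tokens are never masked
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
print(attention_mask.shape)   # torch.Size([1, 1, 6, 10])

Note that torch.cat does not broadcast, so the prefix mask's leading batch dimension of 1 only lines up with a batch of 1 as written; a larger batch would need the prefix mask expanded to the same batch size.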
@@ -908,10 +911,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         mask_token = MASK if MASK in input_ids else gMASK
         use_gmask = False if MASK in input_ids else gMASK

-        mask_position = seq.index(mask_token)
+        mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
         position_ids = self.get_position_ids(
-            seq=seq,
-            mask_position=mask_position,
+            input_ids,
+            mask_positions=mask_positions,
             device=input_ids.device,
             gmask=use_gmask
         )
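The MASK/gMASK position is likewise computed per sequence now. A standalone sketch with made-up special-token ids (the real ids come from the ChatGLM tokenizer and config):

import torch

# Per-sequence mask positions, mirroring the change above (assumed ids: MASK=7, gMASK=8).
MASK, gMASK = 7, 8
input_ids = torch.tensor([
    [11, 12,  8, 5, 21, 22],
    [31,  8, 33, 5, 41, 42],
])
mask_token = MASK if MASK in input_ids else gMASK    # tensor membership test, as in the patch
use_gmask = False if MASK in input_ids else gMASK
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
print(mask_token, use_gmask, mask_positions)         # 8 8 [2, 1]

Each row can place its mask token at a different offset, and get_position_ids consumes the resulting list one entry per sequence.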