zxdu20 committed
Commit 8127ab6
Parent: fbda120

Support batch training

Files changed (1):
  1. modeling_chatglm.py +26 -23
modeling_chatglm.py CHANGED
@@ -818,33 +818,37 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return past_key_values
 
     @staticmethod
-    def get_masks(self, seq, device):
-        context_length = seq.index(self.config.bos_token_id) + 1
-
-        attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
+    def get_masks(self, input_ids, device):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
+        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
         attention_mask.tril_()
-        attention_mask[..., :context_length - 1] = 1
+        for i, context_length in enumerate(context_lengths):
+            attention_mask[i, :, :context_length] = 1
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
         return attention_mask
 
-    def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = len(seq)
+    def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
+        batch_size, seq_length = input_ids.shape
+        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            seq_length = seq.index(self.config.bos_token_id)
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[seq_length:] = mask_position
-            block_position_ids = torch.cat((
-                torch.zeros(seq_length, dtype=torch.long, device=device),
-                torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1
-            ))
-            position_ids = torch.stack((position_ids, block_position_ids), dim=0)
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[i, context_length:] = mask_positions[i]
+            block_position_ids = [torch.cat((
+                torch.zeros(context_length, dtype=torch.long, device=device),
+                torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
+            )) for context_length in context_lengths]
+            block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
-            position_ids = torch.arange(context_length, dtype=torch.long, device=device)
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).expand(batch_size, seq_length)
             if not gmask:
-                position_ids[context_length - 1:] = mask_position
+                for i, context_length in enumerate(context_lengths):
+                    position_ids[context_length:] = mask_positions[i]
 
         position_ids = position_ids.unsqueeze(0)
 
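The batched get_masks above builds one prefix-LM mask per sequence instead of assuming a single sequence. A minimal standalone sketch of that masking logic, assuming a placeholder bos_token_id of 5 and a made-up two-sequence toy batch (the real model takes the id from its config):

import torch

# Sketch of the batched masking logic above; bos_token_id=5 is a placeholder.
def build_batch_masks(input_ids: torch.Tensor, bos_token_id: int = 5) -> torch.Tensor:
    batch_size, seq_length = input_ids.shape
    # The first bos token marks where the bidirectional context ends.
    context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]
    attention_mask = torch.ones((batch_size, seq_length, seq_length))
    attention_mask.tril_()                           # causal lower-triangular base
    for i, context_length in enumerate(context_lengths):
        attention_mask[i, :, :context_length] = 1    # context tokens stay fully visible
    attention_mask.unsqueeze_(1)                     # broadcast over attention heads
    return (attention_mask < 0.5).bool()             # True marks a masked position

toy_batch = torch.tensor([[11, 12, 13, 5, 21, 22],
                          [11, 12, 5, 21, 22, 23]])
print(build_batch_masks(toy_batch).shape)            # torch.Size([2, 1, 6, 6])

Each sequence gets its own context length, so sequences whose bos token sits at a different offset no longer reuse the first sequence's mask.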
 
@@ -890,16 +894,15 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
         else:
             past_key_values = tuple([None] * len(self.layers))
-        seq = input_ids[0].tolist()
 
         if attention_mask is None:
             attention_mask = self.get_masks(
-                seq=seq,
+                input_ids,
                 device=input_ids.device
             )
 
         if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(1, 1, len(seq), self.pre_seq_len).to(attention_mask.device)
+            prefix_attention_mask = torch.ones(1, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
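With p-tuning, pre_seq_len learned prefix key/value positions are prepended, so the mask needs pre_seq_len extra key columns that are never masked; that is what the prefix_attention_mask concatenation above provides. A small shape sketch with made-up sizes (pre_seq_len=4, seq_len=6):

import torch

pre_seq_len, seq_len = 4, 6                                    # placeholder sizes
attention_mask = torch.zeros(1, 1, seq_len, seq_len).bool()    # stand-in for the model's mask (True = masked)
prefix_attention_mask = torch.ones(1, 1, seq_len, pre_seq_len)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()   # all False: the prefix is always attendable
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
print(attention_mask.shape)                                    # torch.Size([1, 1, 6, 10])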
 
@@ -908,10 +911,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             mask_token = MASK if MASK in input_ids else gMASK
             use_gmask = False if MASK in input_ids else gMASK
 
-            mask_position = seq.index(mask_token)
+            mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
             position_ids = self.get_position_ids(
-                seq=seq,
-                mask_position=mask_position,
+                input_ids,
+                mask_positions=mask_positions,
                 device=input_ids.device,
                 gmask=use_gmask
            )
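mask_positions is now computed per sequence and get_position_ids receives the whole batch. A standalone sketch of the 2D position ids it builds in the gMASK branch, again assuming a placeholder bos_token_id of 5 and a toy batch:

import torch

# Sketch of the batched 2D position ids above (gmask branch); bos_token_id=5 is a placeholder.
def build_2d_position_ids(input_ids: torch.Tensor, bos_token_id: int = 5) -> torch.Tensor:
    batch_size, seq_length = input_ids.shape
    context_lengths = [seq.tolist().index(bos_token_id) for seq in input_ids]
    # First channel: absolute positions, identical for every sequence.
    position_ids = torch.arange(seq_length, dtype=torch.long).expand(batch_size, seq_length)
    # Second channel: zeros over the context, then 1, 2, ... inside each generated block.
    block_position_ids = torch.stack([
        torch.cat((torch.zeros(context_length, dtype=torch.long),
                   torch.arange(seq_length - context_length, dtype=torch.long) + 1))
        for context_length in context_lengths
    ], dim=0)
    return torch.stack((position_ids, block_position_ids), dim=1)  # (batch, 2, seq_len)

toy_batch = torch.tensor([[11, 12, 13, 5, 21, 22],
                          [11, 12, 5, 21, 22, 23]])
print(build_2d_position_ids(toy_batch))
# tensor([[[0, 1, 2, 3, 4, 5],
#          [0, 0, 0, 1, 2, 3]],
#         [[0, 1, 2, 3, 4, 5],
#          [0, 0, 1, 2, 3, 4]]])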
 