zxdu20 fzhang commited on
Commit
551a50e
1 Parent(s): 23ad39b

fix typo in use_gmask (#21)

Browse files

- fix typo in use_gmask (d6504255afdd555d12137fc3af04646f099b5785)


Co-authored-by: Fan Zhang <fzhang@users.noreply.huggingface.co>

Files changed (1) hide show
  1. modeling_chatglm.py +2 -2
modeling_chatglm.py CHANGED
@@ -923,7 +923,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
923
  if position_ids is None:
924
  MASK, gMASK = 150000, 150001
925
  mask_token = MASK if MASK in input_ids else gMASK
926
- use_gmask = False if MASK in input_ids else gMASK
927
 
928
  mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
929
  position_ids = self.get_position_ids(
@@ -1086,7 +1086,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1086
  batch_size, seq_length = input_ids.shape
1087
  MASK, gMASK = 150000, 150001
1088
  mask_token = MASK if MASK in input_ids else gMASK
1089
- use_gmask = False if MASK in input_ids else gMASK
1090
  seqs = input_ids.tolist()
1091
  mask_positions = [seq.index(mask_token) for seq in seqs]
1092
 
 
923
  if position_ids is None:
924
  MASK, gMASK = 150000, 150001
925
  mask_token = MASK if MASK in input_ids else gMASK
926
+ use_gmask = False if MASK in input_ids else True
927
 
928
  mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
929
  position_ids = self.get_position_ids(
 
1086
  batch_size, seq_length = input_ids.shape
1087
  MASK, gMASK = 150000, 150001
1088
  mask_token = MASK if MASK in input_ids else gMASK
1089
+ use_gmask = False if MASK in input_ids else True
1090
  seqs = input_ids.tolist()
1091
  mask_positions = [seq.index(mask_token) for seq in seqs]
1092