Use gMASK in the first place (check for gMASK before MASK)
Browse files
- modeling_chatglm.py +4 -4
modeling_chatglm.py
CHANGED
@@ -922,8 +922,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
922 |
|
923 |
if position_ids is None:
|
924 |
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
925 |
-
mask_token =
|
926 |
-
use_gmask =
|
927 |
|
928 |
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
|
929 |
position_ids = self.get_position_ids(
|
@@ -1085,8 +1085,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
1085 |
) -> dict:
|
1086 |
batch_size, seq_length = input_ids.shape
|
1087 |
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
1088 |
-
mask_token =
|
1089 |
-
use_gmask =
|
1090 |
seqs = input_ids.tolist()
|
1091 |
mask_positions = [seq.index(mask_token) for seq in seqs]
|
1092 |
|
|
|
922 |
|
923 |
if position_ids is None:
|
924 |
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
925 |
+
mask_token = gMASK if gMASK in input_ids else MASK
|
926 |
+
use_gmask = True if gMASK in input_ids else False
|
927 |
|
928 |
mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
|
929 |
position_ids = self.get_position_ids(
|
|
|
1085 |
) -> dict:
|
1086 |
batch_size, seq_length = input_ids.shape
|
1087 |
MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
|
1088 |
+
mask_token = gMASK if gMASK in input_ids else MASK
|
1089 |
+
use_gmask = True if gMASK in input_ids else False
|
1090 |
seqs = input_ids.tolist()
|
1091 |
mask_positions = [seq.index(mask_token) for seq in seqs]
|
1092 |
|