About the unusual attention_mask of ChatGLM

#40
by hiyouga - opened

Hi,

Thanks for your wonderful work.

I found that the attention mask of ChatGLM uses 1 to indicate the indices to be masked and 0 to indicate the indices not to be masked, which is the opposite of Hugging Face's implementation (see [1]), where 1 marks tokens that are not masked. Although this is an implementation choice, ChatGLM's attention mask may cause unexpected problems. For example, it is incompatible with the Prompt-Tuning and P-Tuning methods provided by Hugging Face's PEFT library (see [2]), which prepend a mask of ones for the virtual prompt tokens, and ChatGLM would interpret those positions as masked. I wonder whether there is a plan to fix this. (A minimal conversion sketch is given after the references below.)

Looking forward to your reply.

Sincerely.

[1] https://github.com/huggingface/transformers/blob/5a71977b8b95d39834f07a1f739305e354bc05d0/src/transformers/models/bert/modeling_bert.py#L828
[2] https://github.com/huggingface/peft/blob/cc82b674b5db38b9a393463d38afe66e8f48ac1c/src/peft/peft_model.py#L728
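For reference, here is a minimal sketch (my own, not part of either library) of converting a ChatGLM-style boolean mask, where True/1 means masked, into the Hugging Face convention, where 1 means not masked:

import numpy as np

# Hypothetical ChatGLM-style mask: True marks positions that should be masked.
chatglm_mask = np.array([[False, False, True],
                         [False, False, True],
                         [False, False, False]])

# Hugging Face convention: 1 marks tokens that are NOT masked, 0 marks masked ones.
hf_style_mask = (~chatglm_mask).astype(np.int64)
print(hf_style_mask)
# [[1 1 0]
#  [1 1 0]
#  [1 1 1]]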


I also noticed the unusual attention_mask of THUDM/chatglm-6b; here are my findings:

from transformers import AutoTokenizer

# loaded once per model, e.g.:
# tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm-6b', trust_remote_code=True)
kwargs = {
    'max_length': 5,
    'padding': True,
    'truncation': True,
    'add_special_tokens': False,
}
text = '汉'
tokenizer(text, **kwargs)
  • ChatGLM-6B
{'input_ids': [5, 64876], 'attention_mask': array([[[False, False],
        [False, False]]]), 'position_ids': array([[0, 1],
       [0, 0]])}
  • ChatGLM2-6B
{'input_ids': [30910, 55313], 'attention_mask': [1, 1], 'position_ids': [0, 1]}
  • bert-base-chinese
{'input_ids': [3727], 'token_type_ids': [0], 'attention_mask': [1]}

False means NOT masked here, and int(False) is 0, which might be where the 0 comes from.

Another thing is that the shape of the attention_mask is unusual as well: it is (1, 2, 2), whereas (2,) would be expected.

The code that generates this attention_mask is here:

attention_mask = np.ones((1, seq_length, seq_length))
attention_mask = np.tril(attention_mask)            # causal (lower-triangular) part
attention_mask[:, :, :context_length] = 1           # context tokens are fully visible
attention_mask = np.bool_(attention_mask < 0.5)     # invert: True = masked, False = not masked
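For example, plugging in seq_length = 2 and context_length = 2 (the values that would correspond to the two-token '汉' example above, assuming both tokens count as context) reproduces the all-False (1, 2, 2) mask shown earlier:

import numpy as np

seq_length, context_length = 2, 2  # assumed values for the two-token example above
attention_mask = np.ones((1, seq_length, seq_length))
attention_mask = np.tril(attention_mask)
attention_mask[:, :, :context_length] = 1
attention_mask = np.bool_(attention_mask < 0.5)
print(attention_mask.shape)  # (1, 2, 2)
print(attention_mask)
# [[[False False]
#   [False False]]]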

To convert the attention_mask to the standard one (1 = not masked, 0 = masked), I used the following code:

attention_mask = np.where([m[0][-1] for m in attention_mask], 0, 1)
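For instance, applied to a hypothetical batch containing just the (1, 2, 2) mask above, it yields the usual 1/0 vector:

import numpy as np

# Hypothetical batch holding the single (1, 2, 2) boolean mask from above.
batch = [np.array([[[False, False],
                    [False, False]]])]
attention_mask = np.where([m[0][-1] for m in batch], 0, 1)
print(attention_mask)  # [[1 1]] -> 1 = not masked, matching the usual convention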
