Large value difference when comparing hidden_states with flash attention ON and OFF
#42, opened by Ye27
Hi there,
I ran the same inference script twice, loading the model with the following two settings:
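# Run 1: flash attention enabled
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b/",
    _attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16).to('cuda')
# Run 2: default attention implementation
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b/",
    torch_dtype=torch.bfloat16).to('cuda')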
I set a debug breakpoint to save some internal states in each run. However, when I compare the saved values manually, the differences are much larger than I expected.
Here are some comparisons:
>>> torch.isclose(flash['last_hidden_state'], noflash['last_hidden_state'])
tensor([[[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., True, False, True],
[False, False, False, ..., False, False, False]]], device='cuda:0')
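A boolean mask from torch.isclose with its default tolerances (rtol=1e-05, atol=1e-08) does not say how large the mismatch is. A minimal sketch of a more quantitative comparison follows. The helper name summarize_diff is hypothetical, and the tolerances are an assumption borrowed from the bfloat16 defaults of torch.testing.assert_close (rtol=1.6e-2, atol=1e-5); the small tensors stand in for the saved states so that the snippet runs on its own.

```python
import torch

def summarize_diff(a, b, rtol=1.6e-2, atol=1e-5):
    """Summarize how far two tensors of the same shape are from each other.

    The default tolerances are an assumption: they mirror the bfloat16
    defaults used by torch.testing.assert_close.
    """
    # Upcast so the subtraction itself is not rounded to bfloat16.
    a32, b32 = a.float(), b.float()
    abs_diff = (a32 - b32).abs()
    close = torch.isclose(a32, b32, rtol=rtol, atol=atol)
    return {
        "max_abs_diff": abs_diff.max().item(),
        "frac_close": close.float().mean().item(),
        # Largest difference per sequence position (reduces the hidden dim).
        "per_position_max": abs_diff.amax(dim=-1),
    }

# Synthetic stand-ins for the saved states, shaped (batch, seq, hidden).
a = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]], dtype=torch.bfloat16)
b = a.clone()
b[0, 1, 0] = 3.5  # perturb one element in the second position

stats = summarize_diff(a, b)
print(stats["max_abs_diff"], stats["frac_close"], stats["per_position_max"])
```

With the saved states, the call would be summarize_diff(flash['last_hidden_state'], noflash['last_hidden_state']). The per_position_max entry shows at which of the 150 sequence positions the two runs start to diverge.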
A closer inspection of last_hidden_state and image_hidden_states gives the following results:
>>> flash['last_hidden_state']
tensor([[[-0.6992, 1.4766, -1.8125, ..., -1.0859, -1.1562, 1.7188],
[-7.6250, 5.0625, 5.5312, ..., 8.3750, 0.8516, 0.3438],
[ 1.6016, 0.1387, 5.2500, ..., -5.8125, 1.2969, -2.9688],
...,
[-2.9375, -4.9688, -8.4375, ..., 5.4062, 3.9688, -0.3887],
[-1.1562, -0.1416, -0.7148, ..., 2.2969, -0.3926, -4.3750],
[-1.0156, -1.8906, -3.9062, ..., 0.2266, 0.6523, -2.7656]]],
device='cuda:0', dtype=torch.bfloat16)
>>> flash['last_hidden_state'].shape
torch.Size([1, 150, 4096])
>>> flash['image_hidden_states']
tensor([[[ 0.6914, 1.7031, 0.7227, ..., 0.0608, -0.1758, -0.3613],
[-0.6523, 1.0234, 0.6094, ..., 0.2598, 1.0469, 0.5195],
[-1.0312, 1.4453, 0.4883, ..., 0.7852, -0.6914, -0.4316],
...,
[-0.8164, 1.1953, -0.3066, ..., 1.0078, 0.2012, -0.5938],
[ 0.5898, -0.5664, 0.6836, ..., -1.6094, 0.5117, -0.9766],
[ 0.3789, 0.9414, -0.1309, ..., 0.9922, 0.1963, 0.4551]],
[[ 0.6914, 1.7031, 0.7227, ..., 0.0608, -0.1758, -0.3613],
[-0.6523, 1.0234, 0.6094, ..., 0.2598, 1.0469, 0.5195],
[-1.0312, 1.4453, 0.4883, ..., 0.7852, -0.6914, -0.4316],
...,
[-0.8164, 1.1953, -0.3066, ..., 1.0078, 0.2012, -0.5938],
[ 0.5898, -0.5664, 0.6836, ..., -1.6094, 0.5117, -0.9766],
[ 0.3789, 0.9414, -0.1309, ..., 0.9922, 0.1963, 0.4551]]],
device='cuda:0', dtype=torch.bfloat16)
>>> noflash['image_hidden_states']
tensor([[[ 0.6875, 1.7109, 0.7109, ..., 0.0562, -0.1709, -0.3750],
[-0.6719, 1.0156, 0.6211, ..., 0.2578, 1.0234, 0.5156],
[-1.0312, 1.4453, 0.4844, ..., 0.7852, -0.6680, -0.4512],
...,
[-0.7969, 1.2109, -0.3203, ..., 1.0391, 0.2246, -0.6016],
[ 0.6172, -0.5742, 0.7227, ..., -1.5938, 0.5195, -0.9844],
[ 0.3809, 0.9375, -0.1357, ..., 1.0078, 0.1973, 0.4492]],
[[ 0.6875, 1.7109, 0.7109, ..., 0.0562, -0.1709, -0.3750],
[-0.6719, 1.0156, 0.6211, ..., 0.2578, 1.0234, 0.5156],
[-1.0312, 1.4453, 0.4844, ..., 0.7852, -0.6680, -0.4512],
...,
[-0.7969, 1.2109, -0.3203, ..., 1.0391, 0.2246, -0.6016],
[ 0.6172, -0.5742, 0.7227, ..., -1.5938, 0.5195, -0.9844],
[ 0.3809, 0.9375, -0.1357, ..., 1.0078, 0.1973, 0.4492]]],
device='cuda:0', dtype=torch.bfloat16)
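To put the image_hidden_states numbers above in perspective, here is a rough sketch that measures a few of the printed pairs in units of bfloat16 spacing. It is pure Python and relies only on the fact that bfloat16 stores 7 explicit mantissa bits. The helper names are hypothetical, the inputs are the 4-decimal printed values rather than the exact stored ones, and default_isclose re-implements the documented torch.isclose formula with its default tolerances.

```python
import math

BF16_MANTISSA_BITS = 7  # explicit mantissa bits in bfloat16

def bf16_ulp(x: float) -> float:
    """Spacing between adjacent bfloat16 values near x (nonzero, normal range)."""
    # frexp gives abs(x) = m * 2**exponent with 0.5 <= m < 1.
    _, exponent = math.frexp(abs(x))
    return 2.0 ** (exponent - 1 - BF16_MANTISSA_BITS)

def ulp_distance(a: float, b: float) -> float:
    """Difference between a and b, in units of the bfloat16 spacing."""
    return abs(a - b) / bf16_ulp(max(abs(a), abs(b)))

def default_isclose(a: float, b: float, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Same formula as torch.isclose: |a - b| <= atol + rtol * |b|."""
    return abs(a - b) <= atol + rtol * abs(b)

# (flash, noflash) pairs copied from the first row of image_hidden_states.
pairs = [(0.6914, 0.6875), (1.7031, 1.7109), (0.7227, 0.7109)]
for flash_val, noflash_val in pairs:
    print(
        flash_val,
        noflash_val,
        "ulps:", round(ulp_distance(flash_val, noflash_val)),
        "isclose (default tolerances):", default_isclose(flash_val, noflash_val),
    )
```

Under these assumptions, the three pairs are only about 1 to 3 bfloat16 steps apart, yet all of them fail the default isclose check, because the default rtol of 1e-05 is far smaller than the relative spacing of bfloat16 (about 2**-8 to 2**-7). This says nothing about the larger differences in last_hidden_state, which would need the same kind of measurement.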
Is this difference expected?