Large value difference when comparing hidden_states with flash attention ON and OFF

#42
by Ye27 - opened

Hi there,
I ran the same inference script twice, with the following settings:

    "HuggingFaceM4/idefics2-8b/",
   _attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16).to('cuda')

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b/",
    torch_dtype=torch.bfloat16).to('cuda')

I set a debug breakpoint to save some internal states. However, when I manually compare the values from the two runs, the differences are much larger than expected.
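
For context, the capture is roughly equivalent to the sketch below (the prompt and image here are placeholders, and in practice I pull the tensors out at a breakpoint rather than from the returned output):

import torch
from PIL import Image
from transformers import AutoProcessor, Idefics2ForConditionalGeneration

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    # _attn_implementation="flash_attention_2",  # toggled between the two runs
).to("cuda")

image = Image.new("RGB", (512, 512))      # placeholder image, not my real input
prompt = "<image>Describe this image."    # placeholder prompt, not my real input
inputs = processor(text=prompt, images=[image], return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model.model(**inputs)       # base Idefics2Model output

saved = {
    "last_hidden_state": outputs.last_hidden_state,
    "image_hidden_states": outputs.image_hidden_states,
}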

Here are some comparisons:

torch.isclose(flash['last_hidden_state'], noflash['last_hidden_state'])
tensor([[[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ...,  True, False,  True],
         [False, False, False,  ..., False, False, False]]], device='cuda:0')
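
To get a sense of the size of the mismatch, I also compute the differences directly; the tolerances below are just values I picked to roughly match bfloat16 precision, not anything official:

import torch

a = flash['last_hidden_state'].float()
b = noflash['last_hidden_state'].float()

abs_diff = (a - b).abs()
rel_diff = abs_diff / b.abs().clamp(min=1e-6)

print("max abs diff :", abs_diff.max().item())
print("mean abs diff:", abs_diff.mean().item())
print("max rel diff :", rel_diff.max().item())

# torch.isclose's defaults (rtol=1e-5, atol=1e-8) are far stricter than what
# bfloat16 (~3 significant decimal digits) can satisfy, so I also check:
print(torch.allclose(a, b, rtol=1e-2, atol=1e-2))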

A closer inspection of last_hidden_state and image_hidden_states gives the following:

flash['last_hidden_state']
tensor([[[-0.6992,  1.4766, -1.8125,  ..., -1.0859, -1.1562,  1.7188],
         [-7.6250,  5.0625,  5.5312,  ...,  8.3750,  0.8516,  0.3438],
         [ 1.6016,  0.1387,  5.2500,  ..., -5.8125,  1.2969, -2.9688],
         ...,
         [-2.9375, -4.9688, -8.4375,  ...,  5.4062,  3.9688, -0.3887],
         [-1.1562, -0.1416, -0.7148,  ...,  2.2969, -0.3926, -4.3750],
         [-1.0156, -1.8906, -3.9062,  ...,  0.2266,  0.6523, -2.7656]]],
       device='cuda:0', dtype=torch.bfloat16)
>>> flash['last_hidden_state'].shape
torch.Size([1, 150, 4096])
>>> flash['image_hidden_states']
tensor([[[ 0.6914,  1.7031,  0.7227,  ...,  0.0608, -0.1758, -0.3613],
         [-0.6523,  1.0234,  0.6094,  ...,  0.2598,  1.0469,  0.5195],
         [-1.0312,  1.4453,  0.4883,  ...,  0.7852, -0.6914, -0.4316],
         ...,
         [-0.8164,  1.1953, -0.3066,  ...,  1.0078,  0.2012, -0.5938],
         [ 0.5898, -0.5664,  0.6836,  ..., -1.6094,  0.5117, -0.9766],
         [ 0.3789,  0.9414, -0.1309,  ...,  0.9922,  0.1963,  0.4551]],

        [[ 0.6914,  1.7031,  0.7227,  ...,  0.0608, -0.1758, -0.3613],
         [-0.6523,  1.0234,  0.6094,  ...,  0.2598,  1.0469,  0.5195],
         [-1.0312,  1.4453,  0.4883,  ...,  0.7852, -0.6914, -0.4316],
         ...,
         [-0.8164,  1.1953, -0.3066,  ...,  1.0078,  0.2012, -0.5938],
         [ 0.5898, -0.5664,  0.6836,  ..., -1.6094,  0.5117, -0.9766],
         [ 0.3789,  0.9414, -0.1309,  ...,  0.9922,  0.1963,  0.4551]]],
       device='cuda:0', dtype=torch.bfloat16)
>>> noflash['image_hidden_states']
tensor([[[ 0.6875,  1.7109,  0.7109,  ...,  0.0562, -0.1709, -0.3750],
         [-0.6719,  1.0156,  0.6211,  ...,  0.2578,  1.0234,  0.5156],
         [-1.0312,  1.4453,  0.4844,  ...,  0.7852, -0.6680, -0.4512],
         ...,
         [-0.7969,  1.2109, -0.3203,  ...,  1.0391,  0.2246, -0.6016],
         [ 0.6172, -0.5742,  0.7227,  ..., -1.5938,  0.5195, -0.9844],
         [ 0.3809,  0.9375, -0.1357,  ...,  1.0078,  0.1973,  0.4492]],

        [[ 0.6875,  1.7109,  0.7109,  ...,  0.0562, -0.1709, -0.3750],
         [-0.6719,  1.0156,  0.6211,  ...,  0.2578,  1.0234,  0.5156],
         [-1.0312,  1.4453,  0.4844,  ...,  0.7852, -0.6680, -0.4512],
         ...,
         [-0.7969,  1.2109, -0.3203,  ...,  1.0391,  0.2246, -0.6016],
         [ 0.6172, -0.5742,  0.7227,  ..., -1.5938,  0.5195, -0.9844],
         [ 0.3809,  0.9375, -0.1357,  ...,  1.0078,  0.1973,  0.4492]]],
       device='cuda:0', dtype=torch.bfloat16)

Is this difference expected?