DeepSpeed ZeRO-3 Compatibility Issue

#4
by BK-Lee - opened

I've run into a compatibility issue between DeepSpeed ZeRO-3 and the SiglipMultiheadAttentionPoolingHead module shown below (the traceback follows the code).
Could you suggest a solution for it?

class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) # this is the problem
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        batch_size = hidden_state.shape[0]
        probe = self.probe.repeat(batch_size, 1, 1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]  # this is the problem [this is where the error occurs!]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = residual + self.mlp(hidden_state)

        return hidden_state[:, 0]
[rank3]:   File "lib/python3.11/site-packages/torch/nn/modules/activation.py", line 1275, in forward
[rank3]:     attn_output, attn_output_weights = F.multi_head_attention_forward(
[rank3]:                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "lib/python3.11/site-packages/torch/nn/functional.py", line 5533, in multi_head_attention_forward
[rank3]:     attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
[rank3]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "lib/python3.11/site-packages/deepspeed/runtime/zero/linear.py", line 118, in zero3_linear_wrap
[rank3]:     return LinearFunctionForZeroStage3.apply(input, weight, bias)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 455, in decorate_fwd
[rank3]:     return fwd(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "lib/python3.11/site-packages/deepspeed/runtime/zero/linear.py", line 62, in forward
[rank3]:     ret = torch.addmm(bias, input, weight.t())
[rank3]: RuntimeError: mat2 must be a matrix, got 1-D tensor
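From the trace, it looks like the 1-D tensor is the attention's output-projection weight: F.multi_head_attention_forward applies linear(attn_output, out_proj_weight, out_proj_bias) functionally instead of calling self.attention.out_proj(...), so ZeRO-3's pre-forward hook never gathers out_proj.weight and the partitioned placeholder reaches torch.addmm. A workaround I've been considering (only a sketch, not verified for training) is to gather the attention parameters for the duration of the call with deepspeed.zero.GatheredParameters; the function name and the attribute path in the usage comment below are placeholders for my setup.

import deepspeed


def zero3_safe_pooling_forward(self, hidden_state):
    """Drop-in replacement for SiglipMultiheadAttentionPoolingHead.forward."""
    batch_size = hidden_state.shape[0]
    probe = self.probe.repeat(batch_size, 1, 1)

    # Gather the (possibly ZeRO-3 partitioned) MultiheadAttention parameters so
    # that out_proj.weight is a full 2-D matrix when the functional path inside
    # F.multi_head_attention_forward reads it.
    with deepspeed.zero.GatheredParameters(list(self.attention.parameters()), modifier_rank=None):
        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

    residual = hidden_state
    hidden_state = self.layernorm(hidden_state)
    hidden_state = residual + self.mlp(hidden_state)
    return hidden_state[:, 0]


# Hypothetical usage: bind the patched forward to the pooling head instance.
# pooling_head = model.vision_model.head  # placeholder attribute path
# pooling_head.forward = zero3_safe_pooling_forward.__get__(pooling_head)

This seems fine for inference, but for training it may be cleaner to replace torch.nn.MultiheadAttention with explicit nn.Linear projections (as the regular Siglip attention layers use) so ZeRO-3's hooks fire on each Linear's forward, at the cost of remapping the in_proj/out_proj weights when loading the pretrained checkpoint.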
