DeepSpeed ZeRO-3 Compatibility Issue
#4
by
BK-Lee
- opened
I've run into a compatibility issue with DeepSpeed ZeRO-3: the attention pooling head below fails in its forward pass (class and traceback follow).
Could you suggest a solution for it?
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Pool a sequence of hidden states into one vector per sample.

    A single learned probe token is used as the attention query over the
    input sequence; the attended result then passes through a LayerNorm + MLP
    residual branch, and the probe position is returned.
    """

    def __init__(self, config: SiglipVisionConfig):
        """Build the probe, attention, LayerNorm and MLP from ``config``.

        Reads ``config.hidden_size``, ``config.num_attention_heads`` and
        ``config.layer_norm_eps``; ``config`` is also handed to ``SiglipMLP``.
        """
        super().__init__()

        width = config.hidden_size

        # Learned query token, stored once and broadcast per batch in forward().
        self.probe = nn.Parameter(torch.randn(1, 1, width))

        # Reported problem: this stock nn.MultiheadAttention is the module the
        # DeepSpeed ZeRO-3 failure is traced back to.
        self.attention = torch.nn.MultiheadAttention(
            width,
            config.num_attention_heads,
            batch_first=True,
        )
        self.layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)

    def forward(self, hidden_state):
        """Return the pooled representation for each sample in the batch.

        ``hidden_state`` is batch-first (dimension 0 is the batch); it serves
        as both keys and values, with the repeated probe as the query.
        """
        # One copy of the probe per sample: (1, 1, hidden) -> (batch, 1, hidden).
        query = self.probe.repeat(hidden_state.shape[0], 1, 1)

        # Reported error site under DeepSpeed ZeRO-3: this call raises
        # "RuntimeError: mat2 must be a matrix, got 1-D tensor" inside
        # F.multi_head_attention_forward's output projection.
        # NOTE(review): presumably out_proj's weight is still a partitioned
        # placeholder because it is passed as a raw tensor rather than through
        # out_proj's own forward -- confirm against the DeepSpeed docs.
        attended = self.attention(query, hidden_state, hidden_state)[0]

        # Residual MLP branch applied on top of the attended probe.
        pooled = attended + self.mlp(self.layernorm(attended))

        # Only the probe position exists along dim 1; drop that dimension.
        return pooled[:, 0]
[rank3]: File "lib/python3.11/site-packages/torch/nn/modules/activation.py", line 1275, in forward
[rank3]: attn_output, attn_output_weights = F.multi_head_attention_forward(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "lib/python3.11/site-packages/torch/nn/functional.py", line 5533, in multi_head_attention_forward
[rank3]: attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "lib/python3.11/site-packages/deepspeed/runtime/zero/linear.py", line 118, in zero3_linear_wrap
[rank3]: return LinearFunctionForZeroStage3.apply(input, weight, bias)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
[rank3]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 455, in decorate_fwd
[rank3]: return fwd(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^
[rank3]: File "lib/python3.11/site-packages/deepspeed/runtime/zero/linear.py", line 62, in forward
[rank3]: ret = torch.addmm(bias, input, weight.t())
[rank3]: RuntimeError: mat2 must be a matrix, got 1-D tensor