training: fix type mismatch when training
#6
by
Jack477
- opened
No description provided.
Jack477
changed pull request status to
open
Jack477
changed pull request title from
fix type mismatch when training
to training: fix type mismatch when training
Error stack when using fp16 training:
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_deepseek.py", line 1252, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1570, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/.cache/huggingface/modules/transformers_modules/modeling_deepseek.py", line 821, in forward
    q = self.q_proj(hidden_states)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1570, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 109, in zero3_linear_wrap
    return LinearFunctionForZeroStage3.apply(input, weight)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs) # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/torch/cuda/amp/autocast_mode.py", line 98, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/zero/linear.py", line 57, in forward
    output = input.matmul(weight.t())
RuntimeError: expected scalar type Float but found Half
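The traceback ends in `input.matmul(weight.t())`, where the two operands apparently have different dtypes (one float32, one float16). This PR has no description, so the actual change is not reproduced here. The snippet below is only a minimal sketch of one common way to avoid this class of error: cast the activations to the dtype of the projection weight before the linear call. The helper name `cast_to_weight_dtype`, the tensor shapes, and the stand-in `q_proj` layer are assumptions made for illustration, not code from `modeling_deepseek.py`.

```python
import torch
import torch.nn as nn


def cast_to_weight_dtype(x: torch.Tensor, layer: nn.Linear) -> torch.Tensor:
    """Hypothetical helper: return x converted to the dtype of the layer's weight."""
    return x.to(layer.weight.dtype)


# Stand-in for the attention query projection; shapes are arbitrary.
q_proj = nn.Linear(4, 3, bias=False).half()  # fp16 weight, as in fp16 training
hidden_states = torch.ones(2, 4)             # float32 activations

# Reproduce the mismatch: float32 input against an fp16 weight.
try:
    hidden_states.matmul(q_proj.weight.t())
    mismatch_error = None
except RuntimeError as exc:
    mismatch_error = str(exc)  # exact wording depends on the PyTorch version

# Align the dtypes before the projection is applied.
aligned = cast_to_weight_dtype(hidden_states, q_proj)
print(mismatch_error)
print(aligned.dtype, q_proj.weight.dtype)
```

The sketch stops at the cast rather than running the fp16 projection, because half-precision matmul support on CPU depends on the PyTorch version. On a GPU, `q_proj(aligned)` would then run with both operands in float16.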
luofuli
changed pull request status to
merged