RuntimeError: shape '[1, 83, 5120]' is invalid for input of size 339968

#6
by Chat-Error - opened

Error when trying to fine-tune this model.
Code is here:
https://gist.github.com/Kimiko-AI/681e7e42ba0b241b2f8829acff60aa31
Traceback (most recent call last):
File "/root/trl/train2.py", line 60, in
trainer.train()
File "/root/trl/trl/trainer/sft_trainer.py", line 451, in train
output = super().train(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1938, in train
return inner_training_loop(
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2279, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3318, in training_step
loss = self.compute_loss(model, inputs)
File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3363, in compute_loss
outputs = model(**inputs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 819, in forward
return model_forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 807, in call
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py", line 1026, in forward
outputs = self.model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py", line 790, in forward
layer_outputs = self._gradient_checkpointing_func(
File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 451, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 230, in forward
outputs = run_function(*args)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py", line 542, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py", line 390, in forward
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
RuntimeError: shape '[1, 83, 5120]' is invalid for input of size 339968

Hi there! Thanks for pointing this out - this should be easy to fix. Let me make a PR.

- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
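
For context, the reshape fails because this model's attention output dimension (num_heads * head_dim) is smaller than its hidden_size. A minimal sketch of the arithmetic, with bsz, q_len, hidden_size and the 339968 element count read off the error above, and the num_heads / head_dim split being an assumption:

# Minimal sketch of the shape arithmetic behind the error above.
# bsz, q_len and hidden_size are read off the traceback; the
# num_heads / head_dim split (32 * 128 = 4096) is an assumption.
bsz, q_len = 1, 83
hidden_size = 5120              # what the old reshape expects per token
num_heads, head_dim = 32, 128   # assumed: 32 * 128 = 4096

attn_elems = bsz * q_len * num_heads * head_dim
print(attn_elems)                  # 339968 -> matches "input of size 339968"
print(bsz * q_len * hidden_size)   # 424960 -> reshape to hidden_size cannot work
# Reshaping to (bsz, q_len, num_heads * head_dim) instead keeps the element
# count consistent, which is what the patch above does.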

Hello @Xenova, I ran into this error now:

File /opt/conda/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:617, in MistralSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
614 key_states = self.k_proj(hidden_states)
615 value_states = self.v_proj(hidden_states)
--> 617 query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
618 key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
619 value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

RuntimeError: shape '[1, 160, 32, 160]' is invalid for input of size 655360

I thought updating the transformers library would solve this since you merged your changes, but the error is still appearing. Can you help, please?
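
For reference, the numbers in this second error point at the same mismatch, just earlier in the forward pass: q_proj produces 4096 features per token, while the view expects num_heads * head_dim = 32 * 160 = 5120. A rough sketch of the arithmetic (the "configured head_dim is 128" reading is an assumption inferred from the errors above):

# Rough sketch of why the .view() above fails; shapes are read off the error
# message, the "configured head_dim is 128" interpretation is an assumption.
bsz, q_len = 1, 160
num_heads, head_dim = 32, 160    # what the view expects: 32 * 160 = 5120 per token
total_elems = 655360             # actual number of elements in query_states

print(total_elems // (bsz * q_len))   # 4096 -> q_proj really outputs 4096 per token
print(num_heads * head_dim)           # 5120 -> the view asks for 5120 per token
# i.e. head_dim appears to be derived as hidden_size // num_heads (160) rather
# than the model's configured head_dim (assumed 128), which is why running an
# up-to-date transformers release, as suggested below, matters.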

I didn't update the transformers library, but I tried applying the change

- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()

and I ran into the same problem you did here.

If the error is occurring on line 617 (which doesn't match the current version), then are you sure you're running the latest version of transformers? The latest version, released yesterday, is v4.43.1.
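
If it helps, a quick way to confirm which transformers build is actually being imported (plain Python, nothing specific to this model):

# Confirm the version and install location that Python actually picks up.
import transformers
print(transformers.__version__)   # expect 4.43.1 or newer
print(transformers.__file__)      # path of the modeling code being used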
