Failed to run the BNB 8-bit quantized version of "OpenGVLab/InternVL2-8B" on Jetson AGX Orin
I am unable to run the BNB (bitsandbytes) 8-bit quantized version of "OpenGVLab/InternVL2-8B" on a Jetson AGX Orin.
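For reference, the model is loaded roughly as in the sketch below (a minimal reconstruction of what I run; the quantization and generation parameters follow the InternVL2 model card, so treat anything not visible in the traceback as an assumption):

```python
import torch
from transformers import AutoModel, AutoTokenizer

path = "OpenGVLab/InternVL2-8B"

# Tokenizer and model, quantized to 8-bit with bitsandbytes (BNB).
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    load_in_8bit=True,          # BNB 8-bit quantization
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).eval()

# Text-only chat (no image), matching the call in the traceback.
generation_config = dict(max_new_tokens=512, do_sample=False)
response, history = model.chat(
    tokenizer, None, "hello, this is james",
    generation_config, history=None, return_history=True,
)
```

The `chat()` call then fails inside flash_attn with the following error: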
```
Exception in thread Thread-3 (chat_with_hist):
Traceback (most recent call last):
File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
self.run()
File "/usr/lib/python3.11/threading.py", line 982, in run
self._target(*self._args, **self._kwargs)
File "/data/myllm/MultilingualE5large_2_2.py", line 130, in chat_with_hist
response, history = models[0].chat(tokenizer, None, "hello, this is james", generation_config, history=None, return_history=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2-8B/3bfd3664dea4f3da628785f5125d30f889701253/modeling_internvl_chat.py", line 286, in chat
generation_output = self.generate(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2-8B/3bfd3664dea4f3da628785f5125d30f889701253/modeling_internvl_chat.py", line 336, in generate
outputs = self.language_model.generate(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py", line 2024, in generate
result = self._sample(
^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py", line 2982, in _sample
outputs = self(**model_inputs, return_dict=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2-8B/3bfd3664dea4f3da628785f5125d30f889701253/modeling_internlm2.py", line 1068, in forward
outputs = self.model(
^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2-8B/3bfd3664dea4f3da628785f5125d30f889701253/modeling_internlm2.py", line 953, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2-8B/3bfd3664dea4f3da628785f5125d30f889701253/modeling_internlm2.py", line 656, in forward
hidden_states, self_attn_weights, present_key_value = self.attention(
^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/accelerate/hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2-8B/3bfd3664dea4f3da628785f5125d30f889701253/modeling_internlm2.py", line 498, in forward
attn_output = self._flash_attention_forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/OpenGVLab/InternVL2-8B/3bfd3664dea4f3da628785f5125d30f889701253/modeling_internlm2.py", line 557, in _flash_attention_forward
attn_output = flash_attn_func(
^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flash_attn/flash_attn_interface.py", line 880, in flash_attn_func
return FlashAttnFunc.apply(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 574, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flash_attn/flash_attn_interface.py", line 546, in forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/dist-packages/flash_attn/flash_attn_interface.py", line 52, in _flash_attn_forward
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: invalid device function
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA
to enable device-side assertions.
```
===========================================================
Here is my environment configuration:
Module: NVIDIA Jetson AGX Orin (64GB ram)
OS: Ubuntu 20.04 Focal Fossa
CUDA Arch BIN: 8.7
L4T: 35.3.1
Jetpack: 5.1.1
CUDA: 11.8.89
cuDNN: 8.6.0.166
Python: 3.11.9
torch: 2.4.0
torchvision: 0.19.0
flash_attn: 2.6.3
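The versions above can be cross-checked from inside Python (a minimal sketch, assuming the packages import cleanly):

```python
import torch
import torchvision
import flash_attn

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("flash_attn:", flash_attn.__version__)
print("CUDA (torch build):", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())
print("device:", torch.cuda.get_device_name(0))
print("compute capability:", torch.cuda.get_device_capability(0))  # (8, 7) on AGX Orin
```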
Any ideas?
It looks like the CUDA kernel is not compatible with the device architecture: "invalid device function" generally means the installed flash_attn binary does not contain kernels built for the GPU's compute capability (sm_87 on the AGX Orin). Unfortunately, we cannot reproduce the bug because we do not have a Jetson AGX Orin device. Please verify that flash_attn was compiled successfully for sm_87 and actually supports the Jetson AGX Orin.
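One way to verify is to call the flash-attn kernel directly, independent of InternVL2; a minimal sketch (the tensor shapes are arbitrary, chosen only to exercise the kernel):

```python
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, num_heads, head_dim), fp16, on the GPU.
q = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 16, 8, 64, dtype=torch.float16, device="cuda")

# If this raises "CUDA error: invalid device function", the installed
# flash_attn binary has no kernels for this GPU's compute capability
# (sm_87 on the AGX Orin) and needs to be rebuilt from source on the
# device, e.g. with `pip install flash-attn --no-build-isolation`
# as described in the flash-attn README.
out = flash_attn_func(q, k, v, causal=True)
print("flash_attn OK:", out.shape)
```

Alternatively, flash attention can be bypassed entirely: the InternVL2 model card documents a `use_flash_attn` flag (assuming your copy of the remote code honors it), so loading with it disabled should fall back to a plain attention path:

```python
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    load_in_8bit=True,
    low_cpu_mem_usage=True,
    use_flash_attn=False,  # skip the flash_attn kernels
    trust_remote_code=True,
).eval()
```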