Running with Transformers raises error: Tensor on device meta is not on the expected device cuda:0
I followed the instructions to run the AI21-Jamba-1.5-Mini model with the transformers library and got the following error:
```
Traceback (most recent call last):
  File "/home/dell/taoz/jamba_tf.py", line 19, in <module>
    outputs = model.generate(input_ids, max_new_tokens=216)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dell/taoz/jamba_test_env/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dell/taoz/jamba_test_env/lib/python3.11/site-packages/transformers/generation/utils.py", line 2048, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/home/dell/taoz/jamba_test_env/lib/python3.11/site-packages/transformers/generation/utils.py", line 3008, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dell/taoz/jamba_test_env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ...
  File "/home/dell/taoz/jamba_test_env/lib/python3.11/site-packages/torch/library.py", line 788, in inner
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/dell/taoz/jamba_test_env/lib/python3.11/site-packages/torch/_library/custom_ops.py", line 471, in fake_impl
    return self._abstract_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dell/taoz/jamba_test_env/lib/python3.11/site-packages/torch/_prims/__init__.py", line 383, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/home/dell/taoz/jamba_test_env/lib/python3.11/site-packages/torch/_prims_common/__init__.py", line 742, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device meta is not on the expected device cuda:0!
```
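For context, `jamba_tf.py` itself is not included in this report; based on the traceback and the model card's transformers example, it was presumably something along these lines (everything other than the `generate()` call is an assumption):

```python
# Rough sketch of jamba_tf.py, reconstructed from the traceback; all details
# except the generate() call (line 19 in the traceback) are assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ai21labs/AI21-Jamba-1.5-Mini"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",   # assumed; actual dtype/device settings not shown
    device_map="auto",    # assumed; shards the model across the 4 A100s
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [{"role": "user", "content": "Hello!"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
)

outputs = model.generate(input_ids, max_new_tokens=216)  # fails as above
print(tokenizer.decode(outputs[0]))
```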
Hardware: R760 with 4xA100
OS: Ubuntu 22.04
Python Version: 3.11.10
vLLM version: 0.6.2
transformers version: 4.45.1
flash-attn version: 2.6.3
Hey @taozhang9527,
From the error message it seems like the `input_ids` tensor and the model are not on the same device; moreover, it seems that one of them is on the meta device. In order to run Jamba successfully, make sure the model is loaded onto the GPU and that the `input_ids` tensor is on the same device as the model, by running:

```python
input_ids = input_ids.to(model.device)
```
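A quick way to confirm both are where they should be before calling `generate()` (illustrative, assuming `model` and `input_ids` are already defined as above):

```python
# Sanity check: both should report a CUDA device, and they should match.
print(model.device)       # e.g. cuda:0 (the device of the model's first parameters)
print(input_ids.device)   # must equal model.device, not cpu or meta
```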
I did try the following options regarding `input_ids`, but both failed with the same error:
```python
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt').to(model.device)
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt').to("cuda:0")
```
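Since moving `input_ids` to the model's device does not help, the meta tensor presumably originates inside the model rather than from the inputs; the traceback runs through torch's custom-op `fake_impl`, which would be consistent with an optional compiled kernel failing to build or import. As a hedged check (this names the optional packages Jamba's Mamba layers can use; the root cause here is an assumption, not a confirmed diagnosis):

```python
# Check whether the optional compiled Mamba kernels are importable; a failed
# build (e.g. due to missing Python headers) can leave them unusable.
import importlib

for pkg in ("mamba_ssm", "causal_conv1d"):
    try:
        importlib.import_module(pkg)
        print(f"{pkg}: OK")
    except Exception as exc:  # ImportError, or a build/ABI error at import time
        print(f"{pkg}: unavailable ({exc})")
```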
This is related to #13 as well. After installing python3.11-dev, the error is gone. We can close the issue.