RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'
When I run inference with phi-2 in Colab, I get this error: RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'.
This is my code:
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U datasets scipy ipywidgets einops
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
inputs = tokenizer('''How can I get a car?''', return_tensors="pt", return_attention_mask=False)
outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]
print(text)
The error content is:
/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py in layer_norm(input, normalized_shape, weight, bias, eps)
2541 layer_norm, (input, weight, bias), input, normalized_shape, weight=weight, bias=bias, eps=eps
2542 )
-> 2543 return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
2544
2545
RuntimeError: "LayerNormKernelImpl" not implemented for 'Half'
transformers version is 4.37.0.dev0
Could you help me figure out how to fix this? Thanks!
btw I got the same error when using transformers==4.26.2
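@luckychao I had the same issue, but on my desktop, where I am running on a CPU. The error means the model was loaded in half precision (float16) and the LayerNorm is being executed on the CPU, where your PyTorch build has no 'Half' kernel for it. If you are on Colab, change your runtime to T4 GPU or any other GPU, make sure the model and inputs are actually on the GPU (for example with model.to("cuda")), and try again. If you plan to run it on a CPU, set torch_dtype=torch.float32 in your code.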
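For reference, the suggestion above could be sketched roughly as follows. This is only a minimal sketch, not the official phi-2 usage: the helper names pick_device_and_dtype and generate are made up for illustration, and it assumes that float16 is what you want on a GPU and float32 on a CPU.

```python
import torch


def pick_device_and_dtype():
    """Return (device, dtype): float16 on a CUDA GPU, float32 on CPU.

    Hypothetical helper, not part of transformers or PyTorch.
    """
    if torch.cuda.is_available():
        return "cuda", torch.float16
    # On CPU, fall back to float32 to avoid the
    # '"LayerNormKernelImpl" not implemented for 'Half'' error.
    return "cpu", torch.float32


def generate(prompt, model_id="microsoft/phi-2", max_length=200):
    """Load the model with a device-appropriate dtype and generate text."""
    # Imported here so the helper above can be used without transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device, dtype = pick_device_and_dtype()
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=dtype, trust_remote_code=True
    ).to(device)  # move the weights to the GPU when one is available
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # The inputs have to live on the same device as the model.
    inputs = tokenizer(
        prompt, return_tensors="pt", return_attention_mask=False
    ).to(device)
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.batch_decode(outputs)[0]
```

Usage would then be print(generate("How can I get a car?")). Keep in mind that float32 on CPU needs roughly twice the memory of float16 and will be slow for a model of this size.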
Thanks for the answer! I'll try