Why is 4-bit quantized performance slower than fp16?
I'm trying to wrap my head around why A is faster than B.
A.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer_large = AutoTokenizer.from_pretrained("google/flan-t5-large")
model_large = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-large", torch_dtype=torch.float16, device_map="auto")
IS FASTER THAN
B.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig

model_id = "google/flan-t5-large"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False
)
model_large = AutoModelForSeq2SeqLM.from_pretrained(model_id, quantization_config=quantization_config)
tokenizer_large = AutoTokenizer.from_pretrained(model_id)
Hi @kapil1611
The load_in_4bit flag activates the 4-bit quantization method described in this paper: https://arxiv.org/abs/2305.14314 - that method stores the linear layers' weights in 4-bit and de-quantizes them on the fly, performing the matmul in either float32 (the default) or half precision. The quantization / de-quantization step adds overhead, which makes it slower than a half-precision model in most cases.
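If you want to measure the gap yourself, here is a minimal timing sketch - the prompt, run count, and max_new_tokens are arbitrary choices, and it assumes a single CUDA GPU:

import time
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig

model_id = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Translate English to German: How old are you?", return_tensors="pt").to("cuda")

def avg_latency(model, n_runs=5):
    # One warm-up call so lazy CUDA initialization doesn't skew the numbers
    model.generate(**inputs, max_new_tokens=32)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        model.generate(**inputs, max_new_tokens=32)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs

fp16_model = AutoModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
print(f"fp16: {avg_latency(fp16_model):.3f} s")
del fp16_model
torch.cuda.empty_cache()

bnb_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)
print(f"4-bit (fp32 compute): {avg_latency(bnb_model):.3f} s")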
By default we use bnb_4bit_compute_dtype=torch.float32: https://github.com/huggingface/transformers/blob/main/src/transformers/utils/quantization_config.py#L204
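You can confirm that default by inspecting the config object directly:

import torch
from transformers import BitsAndBytesConfig

cfg = BitsAndBytesConfig(load_in_4bit=True)
print(cfg.bnb_4bit_compute_dtype)  # torch.float32 when no compute dtype is passed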
For faster generation, you can benefit from the optimized kernels described here: https://twitter.com/Tim_Dettmers/status/1683118705956491264?s=20 - first make sure you have the latest stable bitsandbytes package (pip install -U bitsandbytes), then run:
import torch
from transformers import BitsAndBytesConfig, AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "google/flan-t5-large"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.float16  # run the matmuls in fp16 instead of the fp32 default
)
model_large = AutoModelForSeq2SeqLM.from_pretrained(model_id, quantization_config=quantization_config)
tokenizer_large = AutoTokenizer.from_pretrained(model_id)
That should hopefully lead to much faster inference than the default 4-bit setup, and maybe similar or even faster speed than fp16 at batch_size=1, depending on the hardware.
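For completeness, a quick way to sanity-check the quantized model once it's loaded (the prompt here is just an example):

input_text = "Translate English to German: How old are you?"
inputs = tokenizer_large(input_text, return_tensors="pt").to(model_large.device)
outputs = model_large.generate(**inputs, max_new_tokens=32)
print(tokenizer_large.decode(outputs[0], skip_special_tokens=True))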