persist dequantized model #1
opened by nudelbrot
Great work done here! Can we somehow reduce the time to first token by saving the result of the (1 min+) dequantize step?
Hi, thank you! Correct me if I'm wrong: are you referring to the first prediction, which takes some time with the PYTORCH_COMPILE backend? That first step is slow because PyTorch is compiling the function that does the dequantization + matmul step so it can run faster. As far as I know, it's not possible to cache that compilation. Instead, you can switch to the PYTORCH backend (`HQQLinear.set_backend(HQQBackend.PYTORCH)`); it doesn't require the compilation step, but it is a bit slower at inference. You can easily switch between the two backends without reloading the model or restarting your session. I hope this helps!
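A minimal sketch of the backend switch described above, assuming the `hqq` package is installed and that `HQQLinear`/`HQQBackend` live in `hqq.core.quantize` (their location in the HQQ repository at the time of writing):

```python
# Sketch: switching HQQ backends at runtime (assumes the `hqq` package is installed).
from hqq.core.quantize import HQQLinear, HQQBackend

# PYTORCH: no torch.compile warm-up, so the first prediction is fast,
# but each dequant + matmul step runs a bit slower afterwards.
HQQLinear.set_backend(HQQBackend.PYTORCH)

# ... run predictions ...

# PYTORCH_COMPILE: pays a one-time (1 min+) compilation cost on the first
# call, then runs faster. No model reload or session restart is needed
# to switch between the two.
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE)
```

The trade-off is first-token latency (PYTORCH) versus steady-state throughput (PYTORCH_COMPILE), so the right choice depends on whether the session runs one prediction or many.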