Torch compile + dynamo error
Hi, I'm trying to run your torch.compile example as-is with:
accelerate==0.34.2
torch==2.4.1
transformers==4.45.1
and I'm getting a torch.compile error from the Dynamo backend:
raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands
from user code:
File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1765, in forward
outputs = self.model(
File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1616, in forward
encoder_outputs = self.encoder(
File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 1121, in forward
layer_outputs = encoder_layer(
File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/home/jupyter/miniconda3/envs/whisper_turbo/lib/python3.11/site-packages/transformers/models/whisper/modeling_whisper.py", line 689, in forward
if hidden_states.dtype == torch.float16 and (
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
Adding
torch._dynamo.config.suppress_errors = True
does not help.
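The failing frame in the traceback (modeling_whisper.py, line 689) is a Python `if` whose outcome depends on tensor values (`torch.isinf(...).any()` / `torch.isnan(...).any()`), which Dynamo cannot decide while tracing. Below is a minimal sketch of the same pattern, assuming the example compiles with `fullgraph=True` (which forbids graph breaks). `scrub` is a hypothetical stand-in, not Whisper code, and `backend="eager"` is used only so the sketch runs on CPU without a C++ toolchain:

```python
import torch


def scrub(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the Whisper encoder-layer check: which branch
    # runs depends on tensor *values*, so it is unknown at trace time.
    if torch.isnan(x).any():
        x = torch.nan_to_num(x)
    return x * 2


x = torch.tensor([1.0, float("nan"), 3.0])

# fullgraph=True disallows graph breaks, so Dynamo raises instead of
# falling back to eager Python for the data-dependent `if`.
strict = torch.compile(scrub, backend="eager", fullgraph=True)
try:
    strict(x)
    raised = False
except Exception as err:
    # The traceback in this thread shows torch._dynamo.exc.UserError on
    # torch 2.4.1; other torch versions may raise a different type.
    raised = True
    print("compile failed with:", type(err).__name__)
```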
So I'm wondering: does running it with torch.compile require more changes than your example shows?
I can also confirm this issue, with:
torch 2.4.1
accelerate 0.34.2
transformers 4.45.1
Same here with torch 2.4.1 and accelerate 0.34.2
Same issue with:
Torch: 2.4.1+cu121
accelerate: 1.0.0
transformers: 4.45.2
You can get it to work with
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=False)
but the speed-up is nowhere near as impressive (18.76 s vs. 14.45 s). I used the code from https://github.com/sanchit-gandhi/notebooks/blob/main/whisper_compile.ipynb