Optimize inference using torch.compile()
This guide aims to provide a benchmark on the inference speed-ups introduced with torch.compile() for computer vision models in 🤗 Transformers.
Benefits of torch.compile
Depending on the model and the GPU, torch.compile() yields up to 30% speed-up during inference. To use torch.compile(), install torch 2.0 or later.
Compiling a model takes time. torch.compile() is therefore most useful when you compile the model once and reuse it for many inferences, rather than compiling it every time you run inference.
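To judge whether compiling pays off, you can estimate how many inferences it takes to amortize the one-off compilation cost. You need n inferences such that compile_time + n × t_compile < n × t_no_compile, which means n > compile_time / (t_no_compile - t_compile). The sketch below assumes a hypothetical 60-second compilation time, because this guide does not report compile times. The per-inference latencies come from the A100 (batch size: 1) DETR row in the tables of this guide.

```python
import math


def break_even_inferences(compile_time_ms: float, eager_ms: float, compiled_ms: float) -> int:
    """Smallest n such that compile_time_ms + n * compiled_ms < n * eager_ms."""
    saving_per_call = eager_ms - compiled_ms
    if saving_per_call <= 0:
        raise ValueError("compiled model is not faster; compilation never pays off")
    # Strict inequality: n must be strictly greater than compile_time / saving.
    return math.floor(compile_time_ms / saving_per_call) + 1


# HYPOTHETICAL compile time of 60 s; latencies (ms) from the A100 batch-size-1 DETR row.
n = break_even_inferences(60_000, eager_ms=34.619, compiled_ms=19.040)
print(n)  # 3852
```

Under these assumptions, the compiled DETR model becomes cheaper overall after roughly 3,900 inferences. With your own measured compile time, the same arithmetic tells you whether compiling is worth it for a short-lived process.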
To compile any computer vision model of your choice, call torch.compile() on the model as shown below:
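import torch
from transformers import AutoModelForImageClassification
MODEL_ID = "google/vit-base-patch16-224"  # any image classification checkpoint
DEVICE = "cuda"
model = AutoModelForImageClassification.from_pretrained(MODEL_ID).to(DEVICE)
model = torch.compile(model)  # the only line you need to add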
compile() comes with multiple modes, which mainly trade compilation time against inference overhead. max-autotune takes longer to compile than reduce-overhead but results in faster inference. The default mode compiles fastest but is less efficient than reduce-overhead at inference time. In this guide, we used the default mode. You can learn more about the modes in the torch.compile documentation.
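You choose the mode with the mode argument of torch.compile(). The sketch below shows this with a small hypothetical nn.Sequential in place of a Transformers vision model, so it runs without downloading weights. torch.compile() returns a wrapped module immediately and defers the actual compilation to the first forward call, so creating one wrapper per mode is cheap. reduce-overhead relies on CUDA graphs, so its benefit is expected mainly on GPUs.

```python
import torch
import torch.nn as nn

# HYPOTHETICAL stand-in for a vision model such as ViT (no weights to download).
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()

# One compiled wrapper per mode. torch.compile() returns right away;
# each wrapper compiles on its own first forward call.
compiled = {
    mode: torch.compile(model, mode=mode)
    for mode in ("default", "reduce-overhead", "max-autotune")
}

# Pick the mode you want to use; only this wrapper will ever be compiled.
fast_model = compiled["default"]
```

Calling fast_model(torch.randn(1, 16)) inside torch.no_grad() triggers compilation for that mode only. The first call is slow and later calls reuse the compiled code.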
We benchmarked torch.compile with different computer vision models, tasks, types of hardware, and batch sizes on torch version 2.0.1.
Benchmarking code
Below you can find the benchmarking code for each task. We warm up the GPU before inference and take the mean time of 300 inferences, using the same image each time.
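The snippets below only set up each model and run a single forward pass. The warm-up-then-average procedure described above can be sketched as a small timing helper. The helper name and the default of 10 warm-up calls are assumptions, not taken from the original benchmarking code. CUDA kernels run asynchronously, so on a GPU you pass a synchronize callback to make sure all queued work has finished before the clock is read.

```python
import time
from typing import Callable, Optional


def mean_latency_ms(
    fn: Callable[[], object],
    n_warmup: int = 10,  # ASSUMED warm-up count; the guide does not state it
    n_iters: int = 300,  # matches "the mean time of 300 inferences"
    synchronize: Optional[Callable[[], None]] = None,
) -> float:
    """Call fn() n_warmup times untimed, then return the mean time of n_iters calls in ms."""
    for _ in range(n_warmup):
        fn()  # warm-up; for a compiled model the first call also triggers compilation
    if synchronize is not None:
        synchronize()  # wait for warm-up work queued on the device
    start = time.perf_counter()
    for _ in range(n_iters):
        fn()
    if synchronize is not None:
        synchronize()  # wait for the last kernels before stopping the clock
    elapsed = time.perf_counter() - start
    return elapsed * 1000 / n_iters
```

On a CUDA device, you would call it inside torch.no_grad() as mean_latency_ms(lambda: model(**processed_input), synchronize=torch.cuda.synchronize). The helper synchronizes only around the timed loop rather than after every call. This gives the same mean while adding less measurement overhead.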
Image Classification with ViT
import torch
from PIL import Image
import requests
import numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification
from accelerate.test_utils.testing import get_backend
device, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").to(device)
model = torch.compile(model)
processed_input = processor(image, return_tensors='pt').to(device)
with torch.no_grad():
    _ = model(**processed_input)
Object Detection with DETR
import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from accelerate.test_utils.testing import get_backend
device, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = AutoModelForObjectDetection.from_pretrained("facebook/detr-resnet-50").to(device)
model = torch.compile(model)
# DETR's image processor takes images only; it has no text-prompt input
inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    _ = model(**inputs)
Image Segmentation with Segformer
import torch
import requests
from PIL import Image
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from accelerate.test_utils.testing import get_backend
device, _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to(device)
model = torch.compile(model)
seg_inputs = processor(images=image, return_tensors="pt").to(device)
with torch.no_grad():
    _ = model(**seg_inputs)
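The snippets above feed a single image, but the tables also report batch sizes 4 and 16. One simple way to build such a batch from the same image is to repeat the preprocessed tensor along the batch dimension. This is an assumption, not taken from the original benchmarking code. Passing images=[image] * batch_size to the processor is an equivalent alternative. The sketch uses a random tensor as a stand-in for processed_input["pixel_values"].

```python
import torch


def repeat_batch(pixel_values: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Tile a single-image tensor of shape (1, C, H, W) into (batch_size, C, H, W)."""
    if pixel_values.shape[0] != 1:
        raise ValueError("expected a single-image batch of shape (1, C, H, W)")
    # expand() creates a view without copying; contiguous() materializes it
    # so the model receives an ordinary dense tensor.
    return pixel_values.expand(batch_size, -1, -1, -1).contiguous()


# Random stand-in for the ViT processor output (3 x 224 x 224 pixels).
pixel_values = torch.randn(1, 3, 224, 224)
batch = repeat_batch(pixel_values, 16)
print(batch.shape)  # torch.Size([16, 3, 224, 224])
```

A compiled model may recompile when the input shape changes, for example when you switch to a new batch size. Warm up separately for each batch size you measure.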
Below you can find the list of the models we benchmarked.
Image Classification
- google/vit-base-patch16-224
- microsoft/beit-base-patch16-224-pt22k-ft22k
- facebook/convnext-large-224
- microsoft/resnet-50
Image Segmentation
- nvidia/segformer-b0-finetuned-ade-512-512
- facebook/mask2former-swin-tiny-coco-panoptic
- facebook/maskformer-swin-base-ade
- google/deeplabv3_mobilenet_v2_1.0_513
Object Detection
Below you can find visualizations of inference durations with and without torch.compile() and the percentage improvement for each model on different hardware and batch sizes.
Below you can find inference durations in milliseconds for each model with and without compile(). Note that OwlViT runs out of memory (OOM) at larger batch sizes, as does DETR at batch size 16 on V100 and T4.
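To turn two durations from these tables into a percentage improvement, you can take the latency reduction relative to the uncompiled run: (t_no_compile - t_compile) / t_no_compile × 100. The speed-up factor t_no_compile / t_compile is another way to read the same numbers. The guide does not spell out its exact formula, so this definition is an assumption. The example uses two rows from the A100 (batch size: 1) table.

```python
def percent_improvement(no_compile_ms: float, compile_ms: float) -> float:
    """Latency reduction from compiling, as a percentage of the uncompiled latency."""
    return (no_compile_ms - compile_ms) / no_compile_ms * 100


def speedup(no_compile_ms: float, compile_ms: float) -> float:
    """How many times faster the compiled model runs."""
    return no_compile_ms / compile_ms


# Rows from the A100 (batch size: 1) table: (no compile, compile) in milliseconds.
rows = {
    "Image Classification/ViT": (9.325, 7.584),
    "Image Classification/ConvNeXT": (10.410, 10.208),
}
for name, (eager, compiled) in rows.items():
    print(f"{name}: {percent_improvement(eager, compiled):.1f}% lower latency, "
          f"{speedup(eager, compiled):.2f}x speed-up")
# Image Classification/ViT: 18.7% lower latency, 1.23x speed-up
# Image Classification/ConvNeXT: 1.9% lower latency, 1.02x speed-up
```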
A100 (batch size: 1)
Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
---|---|---|
Image Classification/ViT | 9.325 | 7.584 |
Image Segmentation/Segformer | 11.759 | 10.500 |
Object Detection/OwlViT | 24.978 | 18.420 |
Image Classification/BeiT | 11.282 | 8.448 |
Object Detection/DETR | 34.619 | 19.040 |
Image Classification/ConvNeXT | 10.410 | 10.208 |
Image Classification/ResNet | 6.531 | 4.124 |
Image Segmentation/Mask2former | 60.188 | 49.117 |
Image Segmentation/Maskformer | 75.764 | 59.487 |
Image Segmentation/MobileNet | 8.583 | 3.974 |
Object Detection/Resnet-101 | 36.276 | 18.197 |
Object Detection/Conditional-DETR | 31.219 | 17.993 |
A100 (batch size: 4)
Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
---|---|---|
Image Classification/ViT | 14.832 | 14.499 |
Image Segmentation/Segformer | 18.838 | 16.476 |
Image Classification/BeiT | 13.205 | 13.048 |
Object Detection/DETR | 48.657 | 32.418 |
Image Classification/ConvNeXT | 22.940 | 21.631 |
Image Classification/ResNet | 6.657 | 4.268 |
Image Segmentation/Mask2former | 74.277 | 61.781 |
Image Segmentation/Maskformer | 180.700 | 159.116 |
Image Segmentation/MobileNet | 14.174 | 8.515 |
Object Detection/Resnet-101 | 68.101 | 44.998 |
Object Detection/Conditional-DETR | 56.470 | 35.552 |
A100 (batch size: 16)
Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
---|---|---|
Image Classification/ViT | 40.944 | 40.010 |
Image Segmentation/Segformer | 37.005 | 31.144 |
Image Classification/BeiT | 41.854 | 41.048 |
Object Detection/DETR | 164.382 | 161.902 |
Image Classification/ConvNeXT | 82.258 | 75.561 |
Image Classification/ResNet | 7.018 | 5.024 |
Image Segmentation/Mask2former | 178.945 | 154.814 |
Image Segmentation/Maskformer | 638.570 | 579.826 |
Image Segmentation/MobileNet | 51.693 | 30.310 |
Object Detection/Resnet-101 | 232.887 | 155.021 |
Object Detection/Conditional-DETR | 180.491 | 124.032 |
V100 (batch size: 1)
Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
---|---|---|
Image Classification/ViT | 10.495 | 6.00 |
Image Segmentation/Segformer | 13.321 | 5.862 |
Object Detection/OwlViT | 25.769 | 22.395 |
Image Classification/BeiT | 11.347 | 7.234 |
Object Detection/DETR | 33.951 | 19.388 |
Image Classification/ConvNeXT | 11.623 | 10.412 |
Image Classification/ResNet | 6.484 | 3.820 |
Image Segmentation/Mask2former | 64.640 | 49.873 |
Image Segmentation/Maskformer | 95.532 | 72.207 |
Image Segmentation/MobileNet | 9.217 | 4.753 |
Object Detection/Resnet-101 | 52.818 | 28.367 |
Object Detection/Conditional-DETR | 39.512 | 20.816 |
V100 (batch size: 4)
Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
---|---|---|
Image Classification/ViT | 15.181 | 14.501 |
Image Segmentation/Segformer | 16.787 | 16.188 |
Image Classification/BeiT | 15.171 | 14.753 |
Object Detection/DETR | 88.529 | 64.195 |
Image Classification/ConvNeXT | 29.574 | 27.085 |
Image Classification/ResNet | 6.109 | 4.731 |
Image Segmentation/Mask2former | 90.402 | 76.926 |
Image Segmentation/Maskformer | 234.261 | 205.456 |
Image Segmentation/MobileNet | 24.623 | 14.816 |
Object Detection/Resnet-101 | 134.672 | 101.304 |
Object Detection/Conditional-DETR | 97.464 | 69.739 |
V100 (batch size: 16)
Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
---|---|---|
Image Classification/ViT | 52.209 | 51.633 |
Image Segmentation/Segformer | 61.013 | 55.499 |
Image Classification/BeiT | 53.938 | 53.581 |
Object Detection/DETR | OOM | OOM |
Image Classification/ConvNeXT | 109.682 | 100.771 |
Image Classification/ResNet | 14.857 | 12.089 |
Image Segmentation/Mask2former | 249.605 | 222.801 |
Image Segmentation/Maskformer | 831.142 | 743.645 |
Image Segmentation/MobileNet | 93.129 | 55.365 |
Object Detection/Resnet-101 | 482.425 | 361.843 |
Object Detection/Conditional-DETR | 344.661 | 255.298 |
T4 (batch size: 1)
Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
---|---|---|
Image Classification/ViT | 16.520 | 15.786 |
Image Segmentation/Segformer | 16.116 | 14.205 |
Object Detection/OwlViT | 53.634 | 51.105 |
Image Classification/BeiT | 16.464 | 15.710 |
Object Detection/DETR | 73.100 | 53.99 |
Image Classification/ConvNeXT | 32.932 | 30.845 |
Image Classification/ResNet | 6.031 | 4.321 |
Image Segmentation/Mask2former | 79.192 | 66.815 |
Image Segmentation/Maskformer | 200.026 | 188.268 |
Image Segmentation/MobileNet | 18.908 | 11.997 |
Object Detection/Resnet-101 | 106.622 | 82.566 |
Object Detection/Conditional-DETR | 77.594 | 56.984 |
T4 (batch size: 4)
Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
---|---|---|
Image Classification/ViT | 43.653 | 43.626 |
Image Segmentation/Segformer | 45.327 | 42.445 |
Image Classification/BeiT | 52.007 | 51.354 |
Object Detection/DETR | 277.850 | 268.003 |
Image Classification/ConvNeXT | 119.259 | 105.580 |
Image Classification/ResNet | 13.039 | 11.388 |
Image Segmentation/Mask2former | 201.540 | 184.670 |
Image Segmentation/Maskformer | 764.052 | 711.280 |
Image Segmentation/MobileNet | 74.289 | 48.677 |
Object Detection/Resnet-101 | 421.859 | 357.614 |
Object Detection/Conditional-DETR | 289.002 | 226.945 |
T4 (batch size: 16)
Task/Model | torch 2.0 - no compile | torch 2.0 - compile |
---|---|---|
Image Classification/ViT | 163.914 | 160.907 |
Image Segmentation/Segformer | 192.412 | 163.620 |
Image Classification/BeiT | 188.978 | 187.976 |
Object Detection/DETR | OOM | OOM |
Image Classification/ConvNeXT | 422.886 | 388.078 |
Image Classification/ResNet | 44.114 | 37.604 |
Image Segmentation/Mask2former | 756.337 | 695.291 |
Image Segmentation/Maskformer | 2842.940 | 2656.88 |
Image Segmentation/MobileNet | 299.003 | 201.942 |
Object Detection/Resnet-101 | 1619.505 | 1262.758 |
Object Detection/Conditional-DETR | 1137.513 | 897.390 |
PyTorch Nightly
We also benchmarked on PyTorch nightly (2.1.0dev) and observed lower latency for both uncompiled and compiled models.
A100
Task/Model | Batch Size | torch nightly - no compile | torch nightly - compile |
---|---|---|---|
Image Classification/BeiT | Unbatched | 12.462 | 6.954 |
Image Classification/BeiT | 4 | 14.109 | 12.851 |
Image Classification/BeiT | 16 | 42.179 | 42.147 |
Object Detection/DETR | Unbatched | 30.484 | 15.221 |
Object Detection/DETR | 4 | 46.816 | 30.942 |
Object Detection/DETR | 16 | 163.749 | 163.706 |
T4
Task/Model | Batch Size | torch nightly - no compile | torch nightly - compile |
---|---|---|---|
Image Classification/BeiT | Unbatched | 14.408 | 14.052 |
Image Classification/BeiT | 4 | 47.381 | 46.604 |
Image Classification/BeiT | 16 | 42.179 | 42.147 |
Object Detection/DETR | Unbatched | 68.382 | 53.481 |
Object Detection/DETR | 4 | 269.615 | 204.785 |
Object Detection/DETR | 16 | OOM | OOM |
V100
Task/Model | Batch Size | torch nightly - no compile | torch nightly - compile |
---|---|---|---|
Image Classification/BeiT | Unbatched | 13.477 | 7.926 |
Image Classification/BeiT | 4 | 15.103 | 14.378 |
Image Classification/BeiT | 16 | 52.517 | 51.691 |
Object Detection/DETR | Unbatched | 28.706 | 19.077 |
Object Detection/DETR | 4 | 88.402 | 62.949 |
Object Detection/DETR | 16 | OOM | OOM |
Reduce Overhead
We benchmarked the reduce-overhead compilation mode on A100 and T4 with PyTorch nightly.
A100
Task/Model | Batch Size | torch nightly - no compile | torch nightly - compile (reduce-overhead) |
---|---|---|---|
Image Classification/ConvNeXT | Unbatched | 11.758 | 7.335 |
Image Classification/ConvNeXT | 4 | 23.171 | 21.490 |
Image Classification/ResNet | Unbatched | 7.435 | 3.801 |
Image Classification/ResNet | 4 | 7.261 | 2.187 |
Object Detection/Conditional-DETR | Unbatched | 32.823 | 11.627 |
Object Detection/Conditional-DETR | 4 | 50.622 | 33.831 |
Image Segmentation/MobileNet | Unbatched | 9.869 | 4.244 |
Image Segmentation/MobileNet | 4 | 14.385 | 7.946 |
T4
Task/Model | Batch Size | torch nightly - no compile | torch nightly - compile (reduce-overhead) |
---|---|---|---|
Image Classification/ConvNeXT | Unbatched | 32.137 | 31.84 |
Image Classification/ConvNeXT | 4 | 120.944 | 110.209 |
Image Classification/ResNet | Unbatched | 9.761 | 7.698 |
Image Classification/ResNet | 4 | 15.215 | 13.871 |
Object Detection/Conditional-DETR | Unbatched | 72.150 | 57.660 |
Object Detection/Conditional-DETR | 4 | 301.494 | 247.543 |
Image Segmentation/MobileNet | Unbatched | 22.266 | 19.339 |
Image Segmentation/MobileNet | 4 | 78.311 | 50.983 |