Optimization

Optimum Intel can be used to apply popular compression techniques such as quantization, pruning and knowledge distillation.

Post-training optimization

Post-training compression techniques such as dynamic and static quantization can easily be applied to your model using our INCQuantizer. Note that quantization is currently only supported on CPUs (only CPU backends are available), so the following examples do not use GPUs / CUDA.
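To illustrate what dynamic quantization does to a tensor, here is a minimal pure-Python sketch of symmetric int8 quantization. It assumes a single scale per tensor (real backends may use per-channel or asymmetric schemes), and the helper names are hypothetical, not part of Optimum Intel or Neural Compressor. The key point is that in dynamic quantization the activation scale is computed on the fly from each incoming tensor.

```python
# Illustrative sketch: symmetric per-tensor int8 quantization (an assumption;
# not the INC implementation). All helper names are hypothetical.


def symmetric_scale(values, num_bits=8):
    """Map the largest |x| to the largest signed integer level (127 for int8)."""
    qmax = 2 ** (num_bits - 1) - 1
    max_abs = max(abs(v) for v in values)
    return max_abs / qmax if max_abs > 0 else 1.0


def quantize(values, scale, num_bits=8):
    """Round to the nearest integer level and clip to [-qmax, qmax]."""
    qmax = 2 ** (num_bits - 1) - 1
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]


def dequantize(q_values, scale):
    """Map integer levels back to approximate float values."""
    return [q * scale for q in q_values]


def dynamic_quantize(activations):
    """Dynamic quantization: the scale is derived from this input at runtime."""
    scale = symmetric_scale(activations)
    return quantize(activations, scale), scale


activations = [0.5, -1.27, 0.02, 1.0]
q, scale = dynamic_quantize(activations)  # q == [50, -127, 2, 100]
restored = dequantize(q, scale)
```

Because the scale is recomputed for every input, no calibration data is needed, at the cost of computing activation ranges during inference.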

Dynamic quantization

You can apply dynamic quantization to your model with the following command:

optimum-cli inc quantize --model distilbert-base-cased-distilled-squad --output quantized_distilbert

When applying post-training quantization, an accuracy tolerance along with a suitable evaluation function can also be specified, so that the tuning process searches for a quantized model meeting these constraints. This can be done for both dynamic and static quantization.

import evaluate
from optimum.intel import INCQuantizer
from datasets import load_dataset
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
from neural_compressor.config import AccuracyCriterion, TuningCriterion, PostTrainingQuantConfig

model_name = "distilbert-base-cased-distilled-squad"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
eval_dataset = load_dataset("squad", split="validation").select(range(64))
task_evaluator = evaluate.evaluator("question-answering")
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer)

def eval_fn(model):
    qa_pipeline.model = model
    metrics = task_evaluator.compute(model_or_pipeline=qa_pipeline, data=eval_dataset, metric="squad")
    return metrics["f1"]

# Set the accepted accuracy loss to 5%
accuracy_criterion = AccuracyCriterion(tolerable_loss=0.05)
# Set the maximum number of trials to 10
tuning_criterion = TuningCriterion(max_trials=10)
quantization_config = PostTrainingQuantConfig(
    approach="dynamic", accuracy_criterion=accuracy_criterion, tuning_criterion=tuning_criterion
)
quantizer = INCQuantizer.from_pretrained(model, eval_fn=eval_fn)
quantizer.quantize(quantization_config=quantization_config, save_directory="dynamic_quantization")
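Conceptually, the accuracy-aware search above evaluates candidate quantization configurations until one stays within the tolerated loss or the trial budget runs out. The sketch below is a simplified, hypothetical model of that loop, not Neural Compressor's actual tuning strategy: it assumes a relative-loss criterion, and the configuration names and scores are made up stand-ins for what eval_fn would return.

```python
from typing import Callable, Iterable, Optional, Tuple


def meets_criterion(baseline: float, score: float, tolerable_loss: float) -> bool:
    """Relative criterion (assumed): accept a drop of at most tolerable_loss * baseline."""
    return score >= baseline * (1 - tolerable_loss)


def tune(
    baseline: float,
    candidates: Iterable[str],
    evaluate: Callable[[str], float],
    tolerable_loss: float = 0.05,
    max_trials: int = 10,
) -> Optional[Tuple[str, float]]:
    """Return the first candidate within tolerance, or None once the budget is spent."""
    for trial, config in enumerate(candidates):
        if trial >= max_trials:
            break
        score = evaluate(config)  # plays the role of eval_fn
        if meets_criterion(baseline, score, tolerable_loss):
            return config, score
    return None


# Hypothetical configurations with made-up F1 scores
fake_scores = {"int8-all-layers": 80.0, "int8-fallback-last-layer": 84.5, "fp32": 87.0}
result = tune(87.0, fake_scores, fake_scores.__getitem__)
```

Here the first candidate falls outside the 5% tolerance (80.0 < 82.65) and the second is accepted; the trial budget bounds how long the search can run when no candidate qualifies.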

Static quantization

Static quantization can be applied in the same manner; it additionally requires a calibration dataset, which is used to perform the calibration step.

from functools import partial
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from neural_compressor.config import PostTrainingQuantConfig
from optimum.intel import INCQuantizer

model_name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The directory where the quantized model will be saved
save_dir = "static_quantization"

def preprocess_function(examples, tokenizer):
    return tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True)

# Load the quantization configuration detailing the quantization we wish to apply
quantization_config = PostTrainingQuantConfig(approach="static")
quantizer = INCQuantizer.from_pretrained(model)
# Generate the calibration dataset needed for the calibration step
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=partial(preprocess_function, tokenizer=tokenizer),
    num_samples=100,
    dataset_split="train",
)
# Apply static quantization and save the resulting model
quantizer.quantize(
    quantization_config=quantization_config,
    calibration_dataset=calibration_dataset,
    save_directory=save_dir,
)
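The calibration step is what distinguishes static from dynamic quantization: activation ranges are observed once on representative samples, and the resulting scale is frozen and reused at inference time. The following pure-Python sketch assumes simple max-abs calibration with a symmetric per-tensor int8 scale; Neural Compressor supports other calibration algorithms, and the helper names here are hypothetical.

```python
# Illustrative sketch: max-abs calibration with a symmetric per-tensor int8 scale
# (an assumption; INC offers other calibration algorithms). Names are hypothetical.


def calibrate_scale(calibration_batches, num_bits=8):
    """Observe max |x| over all calibration batches and freeze a single scale."""
    qmax = 2 ** (num_bits - 1) - 1
    max_abs = max(abs(v) for batch in calibration_batches for v in batch)
    return max_abs / qmax if max_abs > 0 else 1.0


def static_quantize(values, scale, num_bits=8):
    """Quantize with the frozen scale; out-of-range values saturate at +/-qmax."""
    qmax = 2 ** (num_bits - 1) - 1
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]


calibration_batches = [[0.1, -0.4, 0.9], [1.27, -0.3, 0.05], [-0.8, 0.6, 0.2]]
scale = calibrate_scale(calibration_batches)  # computed once, reused at inference
q = static_quantize([0.3, 2.0], scale)  # q == [30, 127]: 2.0 saturates
```

Since 2.0 lies outside the calibrated range, it saturates at the largest integer level, which is why the calibration samples should be representative of real inputs.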

Specify quantization recipes

The SmoothQuant methodology is available for post-training quantization. It usually improves model accuracy compared with other post-training static quantization methods by migrating the quantization difficulty from activations to weights through a mathematically equivalent transformation.

- quantization_config = PostTrainingQuantConfig(approach="static")
+ recipes={"smooth_quant": True, "smooth_quant_args": {"alpha": 0.5, "folding": True}}
+ quantization_config = PostTrainingQuantConfig(approach="static", backend="ipex", recipes=recipes)
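The "mathematically equivalent transformation" can be illustrated with a small pure-Python sketch following the SmoothQuant formulation: each input channel j gets a smoothing factor s_j = max(|X_j|)^alpha / max(|W_j|)^(1 - alpha); activations are divided by s and the matching weight rows multiplied by s, so X·W is unchanged while activation outliers shrink. This only illustrates the idea; the matrices are made-up plain lists, alpha = 0.5 mirrors the recipe above, and the folding option is not modeled.

```python
# Illustrative SmoothQuant-style smoothing on plain lists (not the INC/IPEX code).


def matmul(a, b):
    """Multiply an m x k matrix by a k x n matrix."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def smoothing_factors(x, w, alpha=0.5):
    """One factor per input channel j: max|X[:, j]|**alpha / max|W[j, :]|**(1 - alpha)."""
    factors = []
    for j in range(len(w)):
        act_max = max(abs(row[j]) for row in x)
        weight_max = max(abs(v) for v in w[j])
        factors.append(act_max ** alpha / weight_max ** (1 - alpha))
    return factors


def smooth(x, w, alpha=0.5):
    s = smoothing_factors(x, w, alpha)
    x_smoothed = [[v / s[j] for j, v in enumerate(row)] for row in x]  # X diag(s)^-1
    w_smoothed = [[v * s[j] for v in w[j]] for j in range(len(w))]      # diag(s) W
    return x_smoothed, w_smoothed


# Made-up activations (2 tokens x 3 channels) with an outlier channel,
# and weights (3 input channels x 2 outputs)
x = [[0.5, 40.0, -1.0], [0.2, -30.0, 0.8]]
w = [[0.3, -0.1], [0.02, 0.05], [0.4, 0.2]]
x_s, w_s = smooth(x, w)
```

With alpha = 0.5, each channel's activation maximum and weight maximum end up equal, so the outlier channel no longer dominates the activation range that has to be quantized.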

Refer to the INC documentation and the list of models quantized with this methodology for more details.

Distributed accuracy-aware tuning

One challenge in model quantization is identifying the configuration that best balances accuracy and performance. Distributed tuning speeds up this time-consuming search by parallelizing it across multiple nodes, with the speed-up scaling linearly with the number of nodes.

To use distributed tuning, set quant_level to 1 and launch your script with mpirun.

- quantization_config = PostTrainingQuantConfig(approach="static")
+ quantization_config = PostTrainingQuantConfig(approach="static", quant_level=1)

mpirun -np <number_of_processes> <RUN_CMD>
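The idea behind distributed tuning is that independent tuning trials can be evaluated by different processes at the same time. The sketch below simulates that on a single machine: trials are assigned round-robin to ranks, each rank evaluates its share, and the results are gathered to pick the best configuration that satisfies a relative accuracy criterion. It is a conceptual model only; Neural Compressor's actual scheduling over MPI may differ, and the configuration names and scores are made up.

```python
from typing import Dict, List, Optional, Tuple

# Conceptual simulation of distributed tuning (not INC's MPI implementation).


def assign_trials(configs: List[str], world_size: int) -> Dict[int, List[str]]:
    """Round-robin assignment: rank r receives configs r, r + world_size, ..."""
    return {rank: configs[rank::world_size] for rank in range(world_size)}


def run_rank(configs: List[str], scores: Dict[str, float]) -> List[Tuple[str, float]]:
    """Each rank evaluates only its own share (the score lookup stands in for eval_fn)."""
    return [(config, scores[config]) for config in configs]


def gather_best(results: List[Tuple[str, float]], baseline: float,
                tolerable_loss: float) -> Optional[Tuple[str, float]]:
    """Keep configs within the relative tolerance and return the highest-scoring one."""
    accepted = [r for r in results if r[1] >= baseline * (1 - tolerable_loss)]
    return max(accepted, key=lambda r: r[1]) if accepted else None


# Hypothetical accuracy of five candidate configurations
scores = {"cfg0": 0.80, "cfg1": 0.88, "cfg2": 0.90, "cfg3": 0.85, "cfg4": 0.91}
assignment = assign_trials(list(scores), world_size=2)
gathered = [r for rank in assignment for r in run_rank(assignment[rank], scores)]
best = gather_best(gathered, baseline=0.92, tolerable_loss=0.05)
```

With N processes each evaluating about 1/N of the trials, the wall-clock time of the search shrinks roughly in proportion to N when trials take similar time.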

Refer to the INC documentation and the text-classification example for more details.

During training optimization

The INCTrainer class provides an API to train your model while combining different compression techniques such as knowledge distillation, pruning and quantization. The INCTrainer is very similar to the 🤗 Transformers Trainer and can replace it with minimal changes to your code.

Quantization

To apply quantization during training, you only need to create the appropriate configuration and pass it to the INCTrainer.

  import evaluate
  import numpy as np
  from datasets import load_dataset
  from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, default_data_collator
- from transformers import Trainer
+ from optimum.intel import INCModelForSequenceClassification, INCTrainer
+ from neural_compressor import QuantizationAwareTrainingConfig

  model_id = "distilbert-base-uncased-finetuned-sst-2-english"
  model = AutoModelForSequenceClassification.from_pretrained(model_id)
  tokenizer = AutoTokenizer.from_pretrained(model_id)
  dataset = load_dataset("glue", "sst2")
  dataset = dataset.map(lambda examples: tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True), batched=True)
  metric = evaluate.load("glue", "sst2")
  compute_metrics = lambda p: metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

  # The directory where the quantized model will be saved
  save_dir = "quantized_model"

  # The configuration detailing the quantization process
+ quantization_config = QuantizationAwareTrainingConfig()

- trainer = Trainer(
+ trainer = INCTrainer(
      model=model,
+     quantization_config=quantization_config,
      args=TrainingArguments(save_dir, num_train_epochs=1.0, do_train=True, do_eval=False),
      train_dataset=dataset["train"].select(range(300)),
      eval_dataset=dataset["validation"],
      compute_metrics=compute_metrics,
      tokenizer=tokenizer,
      data_collator=default_data_collator,
  )

  train_result = trainer.train()
  metrics = trainer.evaluate()
  trainer.save_model()

- model = AutoModelForSequenceClassification.from_pretrained(save_dir)
+ model = INCModelForSequenceClassification.from_pretrained(save_dir)
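Quantization-aware training commonly works by inserting "fake quantization" into the forward pass: values are quantized and immediately dequantized, so the model trains against the rounding error it will see at inference, while gradients flow through the rounding as if it were the identity (the straight-through estimator). The pure-Python sketch below illustrates that idea on a list of values; it assumes symmetric per-tensor int8 with a fixed, hypothetical scale and does not reproduce how Neural Compressor inserts these operations.

```python
# Illustrative fake quantization + clipped straight-through estimator
# (an assumption-level sketch, not the INC implementation).


def fake_quantize(values, scale, num_bits=8):
    """Quantize then dequantize: the output stays float but lies on the int grid."""
    qmax = 2 ** (num_bits - 1) - 1
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]


def straight_through_grad(values, upstream_grad, scale, num_bits=8):
    """Pass the gradient through unchanged inside the representable range, zero outside."""
    qmax = 2 ** (num_bits - 1) - 1
    limit = qmax * scale
    return [g if -limit <= v <= limit else 0.0 for v, g in zip(values, upstream_grad)]


weights = [0.123, -0.5, 2.0]
scale = 0.01  # hypothetical scale: representable range is about [-1.27, 1.27]
fq = fake_quantize(weights, scale)
grads = straight_through_grad(weights, [1.0, 1.0, 1.0], scale)
```

The clipped weight (2.0) receives no gradient, while in-range weights are updated as if rounding were not there.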

Pruning

In the same manner, pruning can be applied by specifying a pruning configuration that details the desired pruning process. To learn more about the supported methodologies, refer to the Neural Compressor documentation. At the moment, pruning is applied to both linear and convolutional layers, but not to other layers such as the embeddings. Note that the sparsity defined in the configuration applies to these layers only, and therefore does not correspond to the global sparsity of the model.

- from transformers import Trainer
+ from optimum.intel import INCTrainer
+ from neural_compressor import WeightPruningConfig

  # The configuration detailing the pruning process
+ pruning_config = WeightPruningConfig(
+     pruning_type="magnitude",
+     start_step=0,
+     end_step=15,
+     target_sparsity=0.2,
+     pruning_scope="local",
+ )

- trainer = Trainer(
+ trainer = INCTrainer(
      model=model,
+     pruning_config=pruning_config,
      args=TrainingArguments(save_dir, num_train_epochs=1.0, do_train=True, do_eval=False),
      train_dataset=dataset["train"].select(range(300)),
      eval_dataset=dataset["validation"],
      compute_metrics=compute_metrics,
      tokenizer=tokenizer,
      data_collator=default_data_collator,
  )

  train_result = trainer.train()
  metrics = trainer.evaluate()
  trainer.save_model()

  model = AutoModelForSequenceClassification.from_pretrained(save_dir)
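The note about layer-wise sparsity can be made concrete with a pure-Python sketch of local magnitude pruning: within each prunable layer, the weights with the smallest absolute values are zeroed until the layer reaches the target sparsity, while other layers (here a hypothetical embedding) are left dense. Layer names and values are made up; this illustrates the magnitude criterion and local scope only, not Neural Compressor's gradual schedule between start_step and end_step.

```python
# Illustrative local magnitude pruning on made-up layers (not the INC implementation).


def magnitude_prune(weights, target_sparsity):
    """Zero the smallest-magnitude weights within one layer (local scope)."""
    n_prune = int(len(weights) * target_sparsity)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_prune]:
        pruned[i] = 0.0
    return pruned


def sparsity(weights):
    """Fraction of weights that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


model = {
    "embeddings": [0.3, -0.2, 0.05, 0.7, -0.9, 0.4, 0.1, -0.6, 0.2, 0.8],  # left dense
    "linear_1": [0.01, -0.5, 0.2, -0.03, 0.9, 0.4, -0.7, 0.06, 0.3, -0.8],
    "linear_2": [0.6, -0.02, 0.15, 0.04, -0.35, 0.5, -0.09, 0.7, 0.25, -0.1],
}
prunable = {"linear_1", "linear_2"}
pruned_model = {
    name: magnitude_prune(w, 0.2) if name in prunable else list(w)
    for name, w in model.items()
}
global_sparsity = sparsity([w for ws in pruned_model.values() for w in ws])
```

Each linear layer reaches the 20% target, but because the embeddings are untouched, the sparsity over all weights is only 4/30 (about 13%).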

Knowledge distillation

Knowledge distillation can also be applied in the same manner. To learn more about the supported methodologies, refer to the Neural Compressor documentation.

- from transformers import Trainer
+ from optimum.intel import INCTrainer
+ from neural_compressor import DistillationConfig

+ teacher_model_id = "textattack/bert-base-uncased-SST-2"
+ teacher_model = AutoModelForSequenceClassification.from_pretrained(teacher_model_id)
+ distillation_config = DistillationConfig(teacher_model=teacher_model)

- trainer = Trainer(
+ trainer = INCTrainer(
      model=model,
+     distillation_config=distillation_config,
      args=TrainingArguments(save_dir, num_train_epochs=1.0, do_train=True, do_eval=False),
      train_dataset=dataset["train"].select(range(300)),
      eval_dataset=dataset["validation"],
      compute_metrics=compute_metrics,
      tokenizer=tokenizer,
      data_collator=default_data_collator,
  )

  train_result = trainer.train()
  metrics = trainer.evaluate()
  trainer.save_model()

  model = AutoModelForSequenceClassification.from_pretrained(save_dir)
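As background, a widely used knowledge distillation objective combines the usual cross-entropy on the labels with a KL-divergence term that pulls the student's temperature-softened predictions toward the teacher's. The pure-Python sketch below computes that loss for a single example; the temperature, the weighting alpha and the T² scaling follow the textbook formulation and are assumptions here, since Neural Compressor's default distillation loss and hyperparameters may differ. The logits are made up.

```python
import math

# Textbook distillation loss for one example (an assumption, not INC's exact default).


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions with strictly positive q."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label, temperature=2.0, alpha=0.5):
    # Hard-label term: cross-entropy of the student against the true label
    ce = -math.log(softmax(student_logits)[label])
    # Soft-label term: match the teacher's softened distribution, scaled by T^2
    soft_teacher = softmax(teacher_logits, temperature)
    soft_student = softmax(student_logits, temperature)
    kd = kl_divergence(soft_teacher, soft_student) * temperature ** 2
    return alpha * ce + (1 - alpha) * kd


teacher = [2.0, -1.0]  # made-up two-class logits
student = [0.5, 0.3]
loss = distillation_loss(student, teacher, label=0)
```

A student that reproduces the teacher's logits has a zero KL term, so the loss then reduces to the weighted cross-entropy alone.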

Loading a quantized model

To load a quantized model hosted locally or on the 🤗 hub, you must instantiate your model using our INCModelForXxx classes.

from optimum.intel import INCModelForSequenceClassification

model_name = "Intel/distilbert-base-uncased-finetuned-sst-2-english-int8-dynamic"
model = INCModelForSequenceClassification.from_pretrained(model_name)

Many more quantized models are available on the hub under the Intel organization.

Inference with Transformers pipeline

The quantized model can then easily be used to run inference with the Transformers pipelines.

from transformers import AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained(model_name)
pipe_cls = pipeline("text-classification", model=model, tokenizer=tokenizer)
text = "He's a dreadful magician."
outputs = pipe_cls(text)

[{'label': 'NEGATIVE', 'score': 0.9880216121673584}]

Check out the examples directory for more sophisticated usage.
