import torch
from torch.ao.quantization import quantize_dynamic

from optimum.fx.optimization import Transformation
from transformers import AutoModel, AutoTokenizer
from transformers.utils.fx import symbolic_trace


class DynamicQuantization(Transformation):
    """Optimum transformation that applies PyTorch dynamic quantization to a traced model."""

    def __init__(self, dtype=torch.qint8, qconfig_spec=None, mapping=None):
        super().__init__()
        self.dtype = dtype
        self.qconfig_spec = qconfig_spec
        self.mapping = mapping

    def transform(self, graph_module):
        # quantize_dynamic swaps eligible submodules (e.g. nn.Linear) for their
        # dynamically quantized counterparts; inplace=False leaves the input module untouched.
        quantized_module = quantize_dynamic(
            graph_module,
            qconfig_spec=self.qconfig_spec,
            dtype=self.dtype,
            mapping=self.mapping,
            inplace=False,
        )
        return quantized_module


# Load the model and tokenizer from a local checkpoint.
model_path = "./models/all-MiniLM-L6-v2"
model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Symbolically trace the model so the transformation can operate on a torch.fx graph.
input_names = ["input_ids", "attention_mask"]
traced_model = symbolic_trace(model, input_names=input_names)

# Apply the transformation; calling it runs transform() on the traced graph module.
transformation = DynamicQuantization(dtype=torch.qint8)
quantized_model = transformation(traced_model)

print(type(quantized_model))
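
# A minimal smoke-test sketch (not part of the original listing): run the quantized,
# traced model on one example input. The sentence below is an arbitrary placeholder,
# and this assumes the traced model returns a dict of tensors, which is the usual
# output shape of models traced with transformers.utils.fx.
sample = tokenizer("This is a test sentence.", return_tensors="pt")
with torch.no_grad():
    outputs = quantized_model(
        input_ids=sample["input_ids"],
        attention_mask=sample["attention_mask"],
    )
print({name: tuple(tensor.shape) for name, tensor in outputs.items()})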