import json
import os
import shutil
from dataclasses import dataclass, field
from typing import Optional, Set

from tqdm import tqdm

from transformers import (
    AutoConfig,
    AutoTokenizer,
    HfArgumentParser
)

import onnx
from optimum.exporters.onnx import main_export, export_models
from optimum.exporters.tasks import TasksManager
from onnxruntime.quantization import (
    quantize_dynamic,
    QuantType
)
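# Default arguments forwarded to onnxruntime's dynamic quantizer: per-channel,
# reduced-range (7-bit) weight quantization, unless a model type overrides
# them below.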
DEFAULT_QUANTIZE_PARAMS = {
    'per_channel': True,
    'reduce_range': True,
}
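# Per-model-type overrides of the default quantization parameters above.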
MODEL_SPECIFIC_QUANTIZE_PARAMS = {
    'whisper': {
        'per_channel': False,
        'reduce_range': False,
    }
}
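# Model types that do not ship with a tokenizer; a failed tokenizer load is
# expected for these and is not treated as an error.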
MODELS_WITHOUT_TOKENIZERS = [
    'wav2vec2'
]


@dataclass
class ConversionArguments:
    """
    Arguments used for converting HuggingFace models to onnx.
    """

    model_id: str = field(
        metadata={
            "help": "Model identifier"
        }
    )
    quantize: bool = field(
        default=False,
        metadata={
            "help": "Whether to quantize the model."
        }
    )
    output_parent_dir: str = field(
        default='./models/',
        metadata={
            "help": "Path where the converted model will be saved."
        }
    )

    task: Optional[str] = field(
        default='auto',
        metadata={
            "help": (
                "The task to export the model for. If not specified, the task will be auto-inferred based on the model. Available tasks depend on the model, but are among:"
                f" {str(list(TasksManager._TASKS_TO_AUTOMODELS.keys()))}. For decoder models, use `xxx-with-past` to export the model using past key values in the decoder."
            )
        }
    )

    opset: int = field(
        default=None,
        metadata={
            "help": (
                "If specified, ONNX opset version to export the model with. Otherwise, the default opset will be used."
            )
        }
    )

    device: str = field(
        default='cpu',
        metadata={
            "help": 'The device to use to do the export.'
        }
    )
    skip_validation: bool = field(
        default=False,
        metadata={
            "help": "Whether to skip validation of the converted model"
        }
    )

    per_channel: bool = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights per channel"
        }
    )
    reduce_range: bool = field(
        default=None,
        metadata={
            "help": "Whether to quantize weights to 7 bits. This may improve accuracy for some models running on non-VNNI machines, especially in per-channel mode"
        }
    )

    output_attentions: bool = field(
        default=False,
        metadata={
            "help": "Whether to output attentions from the model. NOTE: This is only supported for whisper models right now."
        }
    )

    split_modalities: bool = field(
        default=False,
        metadata={
            "help": "Whether to split multimodal models. NOTE: This is only supported for CLIP models right now."
        }
    )


def get_operators(model: onnx.ModelProto) -> Set[str]:
    """
    Return the set of ONNX operator types used in the model graph, including
    any nested subgraphs (e.g. inside If/Loop nodes).
    """
    operators = set()

    def traverse_graph(graph):
        for node in graph.node:
            operators.add(node.op_type)
            for attr in node.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    subgraph = attr.g
                    traverse_graph(subgraph)

    traverse_graph(model.graph)
    return operators


def quantize(model_names_or_paths, **quantize_kwargs):
    """
    Quantize the weights of the models from float32 to int8 to allow very efficient inference on modern CPUs.

    Uses unsigned ints for activation values and signed ints for weights, per
    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection,
    as it is faster on most CPU architectures.

    Args:
        model_names_or_paths: Paths to the exported ONNX models to quantize.

    Writes `<name>_quantized.onnx` next to each input model, plus a
    `quantize_config.json` describing the settings used.
    """

    quantize_config = dict(
        **quantize_kwargs,
        per_model_config={}
    )

    for model in tqdm(model_names_or_paths, desc='Quantizing'):
        directory_path = os.path.dirname(model)
        file_name_without_extension = os.path.splitext(
            os.path.basename(model))[0]
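        # Choose the weight type from the exported graph: unsigned 8-bit
        # weights when the model contains Conv nodes, signed 8-bit otherwise
        # (assumption: this avoids runtimes that handle signed-int8 Conv
        # weights poorly).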
        loaded_model = onnx.load_model(model)
        op_types = get_operators(loaded_model)
        weight_type = QuantType.QUInt8 if 'Conv' in op_types else QuantType.QInt8
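        # Quantize the weights, writing `<name>_quantized.onnx` next to the
        # original file. EnableSubgraph lets the quantizer also process nodes
        # inside control-flow subgraphs.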
        quantize_dynamic(
            model_input=model,
            model_output=os.path.join(
                directory_path, f'{file_name_without_extension}_quantized.onnx'),
            weight_type=weight_type,
            optimize_model=False,
            extra_options=dict(
                EnableSubgraph=True
            ),
            **quantize_kwargs
        )

        quantize_config['per_model_config'][file_name_without_extension] = dict(
            op_types=list(op_types),
            weight_type=str(weight_type),
        )
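    # Save the settings used (including per-model operator and weight-type
    # info) alongside the models; note this uses the directory of the last
    # model processed, so all inputs are assumed to share one directory.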
    with open(os.path.join(directory_path, 'quantize_config.json'), 'w') as fp:
        json.dump(quantize_config, fp, indent=4)


def main():
    """
    Example usage:
    python quantize.py --model_id sentence-transformers/all-MiniLM-L6-v2-unquantized --quantize --task default
    """
    parser = HfArgumentParser(
        (ConversionArguments, )
    )
    conv_args, = parser.parse_args_into_dataclasses()

    model_id = conv_args.model_id

    output_model_folder = os.path.join(conv_args.output_parent_dir, model_id)
    os.makedirs(output_model_folder, exist_ok=True)

    config = AutoConfig.from_pretrained(model_id)
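    # Load the tokenizer if one exists; some model types (see
    # MODELS_WITHOUT_TOKENIZERS) legitimately ship without one.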
    tokenizer = None
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    except KeyError:
        pass
    except Exception as e:
        if config.model_type not in MODELS_WITHOUT_TOKENIZERS:
            raise e

    export_kwargs = dict(
        model_name_or_path=model_id,
        output=output_model_folder,
        task=conv_args.task,
        opset=conv_args.opset,
        device=conv_args.device,
        do_validation=not conv_args.skip_validation,
    )
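    # Export the model to ONNX via optimum's main_export.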
    main_export(**export_kwargs)
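    # Optionally quantize every exported .onnx file, applying model-specific
    # parameter overrides (e.g. whisper) where defined.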
    if conv_args.quantize:
        quantize_config = MODEL_SPECIFIC_QUANTIZE_PARAMS.get(
            config.model_type, DEFAULT_QUANTIZE_PARAMS)

        quantize([
            os.path.join(output_model_folder, x)
            for x in os.listdir(output_model_folder)
            if x.endswith('.onnx') and not x.endswith('_quantized.onnx')
        ], **quantize_config)
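    # Move all ONNX artifacts (including external .onnx_data files) into an
    # 'onnx' subfolder of the output directory.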
    os.makedirs(os.path.join(output_model_folder, 'onnx'), exist_ok=True)
    for file in os.listdir(output_model_folder):
        if file.endswith(('.onnx', '.onnx_data')):
            shutil.move(os.path.join(output_model_folder, file),
                        os.path.join(output_model_folder, 'onnx', file))


if __name__ == '__main__':
    main()