#!/usr/bin/env python
"""Release an OpenNMT-py model for inference."""
import argparse

import torch


def main(argv=None):
    """Convert a checkpoint into a released model.

    With ``--format pytorch`` (the default) the checkpoint is loaded on CPU,
    its ``"optim"`` entry is set to ``None`` and the result is saved to the
    output path. With ``--format ctranslate2`` the checkpoint path is handed
    to ``ctranslate2.converters.OpenNMTPyConverter`` and converted into the
    output path, overwriting it if present (``force=True``).

    Args:
        argv: Optional list of command-line arguments. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behaviour of the script entry point.
    """
    parser = argparse.ArgumentParser(
        description="Release an OpenNMT-py model for inference"
    )
    parser.add_argument("--model", "-m", help="The model path", required=True)
    parser.add_argument("--output", "-o", help="The output path", required=True)
    parser.add_argument(
        "--format",
        choices=["pytorch", "ctranslate2"],
        default="pytorch",
        help="The format of the released model",
    )
    # Only read by the ctranslate2 branch below.
    parser.add_argument(
        "--quantization",
        "-q",
        choices=["int8", "int16", "float16", "int8_float16"],
        default=None,
        help="Quantization type for CT2 model.",
    )
    opt = parser.parse_args(argv)

    if opt.format == "pytorch":
        # The checkpoint is loaded only here: the ctranslate2 converter is
        # given the path and never uses an in-memory copy.
        # NOTE(security): torch.load unpickles the file; only release
        # checkpoints that come from a trusted source.
        model = torch.load(opt.model, map_location=torch.device("cpu"))
        model["optim"] = None
        torch.save(model, opt.output)
    elif opt.format == "ctranslate2":
        # Imported lazily so the pytorch path works without ctranslate2.
        import ctranslate2

        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
        converter.convert(opt.output, force=True, quantization=opt.quantization)


if __name__ == "__main__":
    main()