#!/usr/bin/env python
import argparse
import torch


def main():
    """Release an OpenNMT-py checkpoint for inference.

    Parses the command line and then, depending on ``--format``:

    * ``pytorch`` (default): loads the checkpoint onto the CPU, sets its
      ``"optim"`` entry to ``None`` and saves the result to ``--output``.
    * ``ctranslate2``: passes the checkpoint path to
      ``ctranslate2.converters.OpenNMTPyConverter`` and converts it into
      ``--output`` with ``force=True`` and the requested ``--quantization``.

    ``--quantization`` is only consulted by the ``ctranslate2`` branch.
    """
    parser = argparse.ArgumentParser(
        description="Release an OpenNMT-py model for inference"
    )
    parser.add_argument("--model", "-m", help="The model path", required=True)
    parser.add_argument("--output", "-o", help="The output path", required=True)
    parser.add_argument(
        "--format",
        choices=["pytorch", "ctranslate2"],
        default="pytorch",
        help="The format of the released model",
    )
    parser.add_argument(
        "--quantization",
        "-q",
        choices=["int8", "int16", "float16", "int8_float16"],
        default=None,
        help="Quantization type for CT2 model.",
    )
    opt = parser.parse_args()

    if opt.format == "pytorch":
        # Load only on this path: the CTranslate2 converter reads the
        # checkpoint from its path itself, so loading it up front for that
        # format would deserialize the whole model for nothing.
        # NOTE: torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(opt.model, map_location=torch.device("cpu"))
        checkpoint["optim"] = None
        torch.save(checkpoint, opt.output)
    elif opt.format == "ctranslate2":
        # Imported lazily so the dependency is only needed for this format.
        import ctranslate2

        converter = ctranslate2.converters.OpenNMTPyConverter(opt.model)
        converter.convert(opt.output, force=True, quantization=opt.quantization)


# Run the command-line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()