|
|
|
|
|
|
|
|
|
|
|
|
|
"""Neural machine translation model decoding script.""" |
|
|
|
import configargparse |
|
import logging |
|
import os |
|
import random |
|
import sys |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for the translation decoding script.

    Options may be given on the command line or through up to three YAML
    config files (``--config``, ``--config2``, ``--config3``).

    Returns:
        configargparse.ArgumentParser: Parser holding every decoding option.

    """
    parser = configargparse.ArgumentParser(
        # This is the text MT decoder, not the speech-translation one.
        description="Translate text using a neural machine translation model "
        "on one CPU or GPU",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )

    # Config files
    parser.add("--config", is_config_file=True, help="Config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="Second config file path that overwrites the settings in `--config`",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="Third config file path "
        "that overwrites the settings in `--config` and `--config2`",
    )

    # General runtime options
    parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
    parser.add_argument(
        "--dtype",
        choices=("float16", "float32", "float64"),
        default="float32",
        help="Float precision (only available in --api v2)",
    )
    # NOTE(review): main() in this file rejects every backend except
    # "pytorch", so relying on the "chainer" default always fails there.
    # The default is kept unchanged for CLI compatibility.
    parser.add_argument(
        "--backend",
        type=str,
        default="chainer",
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
    parser.add_argument(
        "--batchsize",
        type=int,
        default=1,
        help="Batch size for beam search (0: means no batch processing)",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    parser.add_argument(
        "--api",
        default="v1",
        choices=["v1", "v2"],
        help="Beam search APIs "
        "v1: Default API. It only supports "
        "the ASRInterface.recognize method and DefaultRNNLM. "
        "v2: Experimental API. "
        "It supports any models that implements ScorerInterface.",
    )

    # Input / output files
    parser.add_argument(
        "--trans-json", type=str, help="Filename of translation data (json)"
    )
    parser.add_argument(
        "--result-label",
        type=str,
        required=True,
        help="Filename of result label data (json)",
    )

    # Model files
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file"
    )

    # Search options
    parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
    parser.add_argument("--penalty", type=float, default=0.1, help="Insertion penalty")
    parser.add_argument(
        "--maxlenratio",
        type=float,
        default=3.0,
        # The real default (3.0) is appended by ArgumentDefaultsHelpFormatter,
        # so the text must not claim that 0.0 is the default.
        help="Input length ratio to obtain max output length. "
        "If maxlenratio=0.0, it uses an end-detect function "
        "to automatically find maximum hypothesis lengths",
    )
    parser.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )

    # Multilingual option
    # NOTE(review): the default is False rather than None although type=str;
    # presumably consumers only test truthiness -- confirm before changing.
    parser.add_argument(
        "--tgt-lang",
        default=False,
        type=str,
        help="target language ID (e.g., <en>, <de>, and <fr> etc.)",
    )
    return parser
|
|
|
|
|
def main(args):
    """Parse the decoding options and run translation with the PyTorch backend.

    The function configures logging from ``--verbose``, validates the GPU
    settings, seeds the ``random`` and NumPy generators, and finally hands
    the parsed options to ``espnet.mt.pytorch_backend.mt.trans``.

    Args:
        args (list): Command-line arguments without the program name.

    Raises:
        NotImplementedError: If ``--dtype`` is not ``float32``.
        ValueError: If ``--backend`` is not ``pytorch``.
        SystemExit: With status 1 when ``--ngpu`` disagrees with
            ``CUDA_VISIBLE_DEVICES`` or when more than one GPU is requested.

    """
    options = get_parser().parse_args(args)

    # Logging: verbose 1 -> INFO, 2 -> DEBUG, anything else -> WARN.
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    verbosity_to_level = {1: logging.INFO, 2: logging.DEBUG}
    logging.basicConfig(
        level=verbosity_to_level.get(options.verbose, logging.WARN),
        format=log_format,
    )
    if options.verbose not in verbosity_to_level:
        logging.warning("Skip DEBUG/INFO messages")

    # The requested GPU count must agree with the visible CUDA devices.
    if options.ngpu > 0:
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible_devices is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        else:
            device_count = len(visible_devices.split(","))
            if device_count != options.ngpu:
                logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
                sys.exit(1)

    # Multi-GPU decoding is rejected outright.
    if options.ngpu > 1:
        logging.error("The program only supports ngpu=1.")
        sys.exit(1)

    python_path = os.environ.get("PYTHONPATH", "(None)")
    logging.info("python path = " + python_path)

    # Seed both generators with the same value.
    seed = options.seed
    random.seed(seed)
    np.random.seed(seed)
    logging.info("set random seed = %d" % seed)

    logging.info("backend = " + options.backend)
    if options.backend != "pytorch":
        raise ValueError("Only pytorch are supported.")

    # Imported here so the backend is only loaded once it is actually needed.
    from espnet.mt.pytorch_backend.mt import trans

    if options.dtype != "float32":
        raise NotImplementedError(
            f"`--dtype {options.dtype}` is only available with `--api v2`"
        )
    trans(options)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: pass everything after the program name to main().
    cli_arguments = sys.argv[1:]
    main(cli_arguments)
|
|