|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
from pathlib import Path |
|
from typing import Callable, List, Optional, Union |
|
|
|
import torch |
|
from fairseq import utils |
|
from fairseq.data.indexed_dataset import get_available_dataset_impl |
|
from fairseq.dataclass.configs import ( |
|
CheckpointConfig, |
|
CommonConfig, |
|
CommonEvalConfig, |
|
DatasetConfig, |
|
DistributedTrainingConfig, |
|
EvalLMConfig, |
|
GenerationConfig, |
|
InteractiveConfig, |
|
OptimizationConfig, |
|
EMAConfig, |
|
) |
|
from fairseq.dataclass.utils import gen_parser_from_dataclass |
|
|
|
|
|
from fairseq.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list |
|
|
|
|
|
def get_preprocessing_parser(default_task="translation"):
    """
    Return a parser carrying the data-preprocessing options.

    Args:
        default_task (str): value used as the default for ``--task``.
    """
    preprocessing_parser = get_parser("Preprocessing", default_task)
    add_preprocess_args(preprocessing_parser)
    return preprocessing_parser
|
|
|
|
|
def get_training_parser(default_task="translation"):
    """
    Return a parser with every option group needed by the trainer.

    Groups are added in a fixed order: dataset (training mode), distributed
    training, model, optimization, checkpoint and EMA.

    Args:
        default_task (str): value used as the default for ``--task``.
    """
    trainer_parser = get_parser("Trainer", default_task)
    group_adders = (
        lambda p: add_dataset_args(p, train=True),
        add_distributed_training_args,
        add_model_args,
        add_optimization_args,
        add_checkpoint_args,
        add_ema_args,
    )
    for add_group in group_adders:
        add_group(trainer_parser)
    return trainer_parser
|
|
|
|
|
def get_generation_parser(interactive=False, default_task="translation"):
    """
    Return a parser for (batch or interactive) sequence generation.

    Args:
        interactive (bool): also add the interactive-mode options.
        default_task (str): value used as the default for ``--task``.
    """
    generation_parser = get_parser("Generation", default_task)
    add_dataset_args(generation_parser, gen=True)
    # Distributed world size defaults to 1 here rather than the GPU count.
    add_distributed_training_args(generation_parser, default_world_size=1)
    add_generation_args(generation_parser)
    add_checkpoint_args(generation_parser)
    if not interactive:
        return generation_parser
    add_interactive_args(generation_parser)
    return generation_parser
|
|
|
|
|
def get_speech_generation_parser(default_task="text_to_speech"):
    """
    Return a parser for speech generation.

    Args:
        default_task (str): value used as the default for ``--task``.
    """
    speech_parser = get_parser("Speech Generation", default_task)
    add_dataset_args(speech_parser, gen=True)
    # Distributed world size defaults to 1 here rather than the GPU count.
    add_distributed_training_args(speech_parser, default_world_size=1)
    add_speech_generation_args(speech_parser)
    return speech_parser
|
|
|
|
|
def get_interactive_generation_parser(default_task="translation"):
    """Return the generation parser with the interactive options enabled."""
    return get_generation_parser(default_task=default_task, interactive=True)
|
|
|
|
|
def get_eval_lm_parser(default_task="language_modeling"):
    """
    Return a parser for evaluating a language model.

    Args:
        default_task (str): value used as the default for ``--task``.
    """
    lm_parser = get_parser("Evaluate Language Model", default_task)
    add_dataset_args(lm_parser, gen=True)
    # Distributed world size defaults to 1 here rather than the GPU count.
    add_distributed_training_args(lm_parser, default_world_size=1)
    add_eval_lm_args(lm_parser)
    return lm_parser
|
|
|
|
|
def get_validation_parser(default_task=None):
    """
    Return a parser for running validation.

    Args:
        default_task (str or None): value used as the default for ``--task``.
    """
    validation_parser = get_parser("Validation", default_task)
    add_dataset_args(validation_parser, train=True)
    # Distributed world size defaults to 1 here rather than the GPU count.
    add_distributed_training_args(validation_parser, default_world_size=1)
    eval_group = validation_parser.add_argument_group("Evaluation")
    add_common_eval_args(eval_group)
    return validation_parser
|
|
|
|
|
def parse_args_and_arch(
    parser: argparse.ArgumentParser,
    input_args: List[str] = None,
    parse_known: bool = False,
    suppress_defaults: bool = False,
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
    """
    Parse arguments, adding model/task/registry-specific options on the fly.

    Parsing happens in two passes. A first ``parse_known_args`` discovers the
    selected ``--arch``, ``--task`` and registry choices. The options specific
    to those choices are then added to *parser*, and a final parse is done
    with the complete parser.

    Args:
        parser (ArgumentParser): the parser. It is mutated in place, because
            groups and options are added to it.
        input_args (List[str]): strings to parse, defaults to sys.argv
        parse_known (bool): only parse known arguments, similar to
            `ArgumentParser.parse_known_args`
        suppress_defaults (bool): parse while ignoring all default values.
            Only options that are explicitly given end up in the result.
        modify_parser (Optional[Callable[[ArgumentParser], None]]):
            function to modify the parser, e.g., to set default values. It is
            called twice per parse (before and after the specific options are
            added), so it must tolerate being applied repeatedly.

    Returns:
        ``argparse.Namespace``, or ``(Namespace, List[str])`` (parsed args
        plus unrecognized strings) when *parse_known* is True.

    Raises:
        RuntimeError: if ``--arch`` is missing or names neither a registered
            architecture nor a registered model.
    """
    if suppress_defaults:
        # Parse twice: once with defaults to discover every task/model-specific
        # option, then again with all defaults forced to None so that only
        # explicitly given options survive.
        full_result = parse_args_and_arch(
            parser,
            input_args=input_args,
            parse_known=parse_known,
            suppress_defaults=False,
            modify_parser=modify_parser,
        )
        # With parse_known=True the recursive call returns (args, extra);
        # calling vars() on that tuple used to raise TypeError.
        full_args = full_result[0] if parse_known else full_result
        suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
        # NOTE(review): argparse's `parents` shares Action objects, so this
        # set_defaults also resets the defaults of `parser` itself.
        suppressed_parser.set_defaults(**{k: None for k in vars(full_args)})
        if parse_known:
            args, extra = suppressed_parser.parse_known_args(input_args)
        else:
            args, extra = suppressed_parser.parse_args(input_args), None
        args = argparse.Namespace(
            **{k: v for k, v in vars(args).items() if v is not None}
        )
        return (args, extra) if parse_known else args

    from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY

    # Import user modules (--user-dir) before the first parse, so that
    # components they register are known when choices are resolved.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    utils.import_user_module(usr_args)

    if modify_parser is not None:
        modify_parser(parser)

    # First pass. The parser does not yet know the model/task/registry-specific
    # options, so this only identifies which components were selected.
    args, _ = parser.parse_known_args(input_args)

    if hasattr(args, "arch"):
        model_specific_group = parser.add_argument_group(
            "Model-specific configuration",
            # Only include attributes which are explicitly given as command-line
            # arguments or which are added with an explicit default value.
            argument_default=argparse.SUPPRESS,
        )
        if args.arch in ARCH_MODEL_REGISTRY:
            ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        elif args.arch in MODEL_REGISTRY:
            MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        else:
            raise RuntimeError(
                f"Unknown or missing --arch ({args.arch!r}): expected the name "
                "of a registered architecture or model"
            )

    if hasattr(args, "task"):
        from fairseq.tasks import TASK_REGISTRY

        TASK_REGISTRY[args.task].add_args(parser)
    if getattr(args, "use_bmuf", False):
        # BMUF options are only added when use_bmuf is set.
        from fairseq.optim.bmuf import FairseqBMUF

        FairseqBMUF.add_args(parser)

    # Add options for each selected registry component (criterion, optimizer, ...).
    from fairseq.registry import REGISTRIES

    for registry_name, registry in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is None:
            continue
        cls = registry["registry"][choice]
        if hasattr(cls, "add_args"):
            cls.add_args(parser)
        elif hasattr(cls, "__dataclass"):
            gen_parser_from_dataclass(parser, cls.__dataclass())

    # Modify the parser a second time, since adding the options above may
    # have reset defaults that modify_parser set.
    if modify_parser is not None:
        modify_parser(parser)

    # Second pass with the complete parser.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    _postprocess_parsed_args(args)

    # Let the architecture's registered config function fill in its values.
    if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    return args


def _postprocess_parsed_args(args: argparse.Namespace) -> None:
    """Fill in derived/fallback values on *args* in place after the final parse."""
    # Validation batch limits fall back to their training counterparts.
    if getattr(args, "batch_size_valid", None) is None:
        args.batch_size_valid = args.batch_size
    if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None:
        args.max_tokens_valid = args.max_tokens

    # Memory-efficient variants imply the corresponding precision flag.
    if getattr(args, "memory_efficient_fp16", False):
        args.fp16 = True
    if getattr(args, "memory_efficient_bf16", False):
        args.bf16 = True
    args.tpu = getattr(args, "tpu", False)
    args.bf16 = getattr(args, "bf16", False)
    # --bf16 implies --tpu here.
    if args.bf16:
        args.tpu = True
    if args.tpu and getattr(args, "fp16", False):
        raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs")

    # Default the seed to 1, but record that it was not given explicitly.
    if getattr(args, "seed", None) is None:
        args.seed = 1
        args.no_seed_provided = True
    else:
        args.no_seed_provided = False

    if getattr(args, "update_epoch_batch_itr", None) is None:
        if hasattr(args, "grouped_shuffling"):
            args.update_epoch_batch_itr = args.grouped_shuffling
        else:
            args.grouped_shuffling = False
            args.update_epoch_batch_itr = False
|
|
|
|
|
def get_parser(desc, default_task="translation"):
    """
    Create the base parser shared by the ``get_*_parser`` helpers.

    User modules named by ``--user-dir`` (read from ``sys.argv``) are imported
    first, so that the components they register appear in the choices below.
    The parser then receives:

    - the ``CommonConfig`` options,
    - one ``--<registry-name>`` option per entry in ``fairseq.registry.REGISTRIES``,
    - ``--task``.

    Args:
        desc (str): description shown in ``--help`` (previously ignored).
        default_task (str): default value for ``--task``.

    Returns:
        argparse.ArgumentParser: the new parser.
    """
    # No argument list is given, so this reads sys.argv.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args()
    utils.import_user_module(usr_args)

    parser = argparse.ArgumentParser(description=desc, allow_abbrev=False)
    gen_parser_from_dataclass(parser, CommonConfig())

    from fairseq.registry import REGISTRIES

    for registry_name, REGISTRY in REGISTRIES.items():
        parser.add_argument(
            "--" + registry_name.replace("_", "-"),
            default=REGISTRY["default"],
            choices=REGISTRY["registry"].keys(),
        )

    # Task definitions can be found under fairseq/tasks/ (import is local).
    from fairseq.tasks import TASK_REGISTRY

    parser.add_argument(
        "--task",
        metavar="TASK",
        default=default_task,
        choices=TASK_REGISTRY.keys(),
        help="task",
    )

    return parser
|
|
|
|
|
def add_preprocess_args(parser):
    """
    Add the data-preprocessing options to *parser* under a "Preprocessing" group.

    Every option, including ``--dataset-impl``, is registered on the same
    argument group so they are listed together in ``--help``. Grouping does
    not affect how arguments are parsed.

    Args:
        parser (argparse.ArgumentParser): parser to extend in place.

    Returns:
        argparse.ArgumentParser: the same *parser*.
    """
    group = parser.add_argument_group("Preprocessing")

    group.add_argument("-s", "--source-lang", default=None, metavar="SRC",
                       help="source language")
    group.add_argument("-t", "--target-lang", default=None, metavar="TARGET",
                       help="target language")
    group.add_argument("--trainpref", metavar="FP", default=None,
                       help="train file prefix (also used to build dictionaries)")
    group.add_argument("--validpref", metavar="FP", default=None,
                       help="comma separated, valid file prefixes "
                            "(words missing from train set are replaced with <unk>)")
    group.add_argument("--testpref", metavar="FP", default=None,
                       help="comma separated, test file prefixes "
                            "(words missing from train set are replaced with <unk>)")
    group.add_argument("--align-suffix", metavar="FP", default=None,
                       help="alignment file suffix")
    group.add_argument("--destdir", metavar="DIR", default="data-bin",
                       help="destination dir")
    group.add_argument("--thresholdtgt", metavar="N", default=0, type=int,
                       help="map words appearing less than threshold times to unknown")
    group.add_argument("--thresholdsrc", metavar="N", default=0, type=int,
                       help="map words appearing less than threshold times to unknown")
    group.add_argument("--tgtdict", metavar="FP",
                       help="reuse given target dictionary")
    group.add_argument("--srcdict", metavar="FP",
                       help="reuse given source dictionary")
    group.add_argument("--nwordstgt", metavar="N", default=-1, type=int,
                       help="number of target words to retain")
    group.add_argument("--nwordssrc", metavar="N", default=-1, type=int,
                       help="number of source words to retain")
    group.add_argument("--alignfile", metavar="ALIGN", default=None,
                       help="an alignment file (optional)")
    # Previously added to `parser` directly, which listed it outside the
    # "Preprocessing" section of --help.
    group.add_argument('--dataset-impl', metavar='FORMAT', default='mmap',
                       choices=get_available_dataset_impl(),
                       help='output dataset implementation')
    group.add_argument("--joined-dictionary", action="store_true",
                       help="Generate joined dictionary")
    group.add_argument("--only-source", action="store_true",
                       help="Only process the source language")
    group.add_argument("--padding-factor", metavar="N", default=8, type=int,
                       help="Pad dictionary size to be multiple of N")
    group.add_argument("--workers", metavar="N", default=1, type=int,
                       help="number of parallel workers")
    group.add_argument("--dict-only", action='store_true',
                       help="if true, only builds a dictionary and then exits")

    return parser
|
|
|
|
|
def add_dataset_args(parser, train=False, gen=False):
    """
    Add the ``DatasetConfig`` options to *parser* and return their group.

    ``train`` and ``gen`` are accepted for call-site compatibility but are not
    read by this function.
    """
    dataset_group = parser.add_argument_group("dataset_data_loading")
    dataset_config = DatasetConfig()
    gen_parser_from_dataclass(dataset_group, dataset_config)
    return dataset_group
|
|
|
|
|
def add_distributed_training_args(parser, default_world_size=None):
    """
    Add the ``DistributedTrainingConfig`` options to *parser*.

    Args:
        parser (argparse.ArgumentParser): parser to extend.
        default_world_size (int or None): default for the distributed world
            size. When None, it becomes the number of visible CUDA devices,
            with a minimum of 1.

    Returns:
        the created argument group.
    """
    dist_group = parser.add_argument_group("distributed_training")
    world_size = (
        default_world_size
        if default_world_size is not None
        else max(1, torch.cuda.device_count())
    )
    dist_config = DistributedTrainingConfig(distributed_world_size=world_size)
    gen_parser_from_dataclass(dist_group, dist_config)
    return dist_group
|
|
|
|
|
def add_optimization_args(parser):
    """Add the ``OptimizationConfig`` options to *parser* and return their group."""
    optim_group = parser.add_argument_group("optimization")
    optim_config = OptimizationConfig()
    gen_parser_from_dataclass(optim_group, optim_config)
    return optim_group
|
|
|
|
|
def add_checkpoint_args(parser):
    """Add the ``CheckpointConfig`` options to *parser* and return their group."""
    ckpt_group = parser.add_argument_group("checkpoint")
    ckpt_config = CheckpointConfig()
    gen_parser_from_dataclass(ckpt_group, ckpt_config)
    return ckpt_group
|
|
|
|
|
def add_common_eval_args(group):
    """Register the ``CommonEvalConfig`` options on an existing argument *group*."""
    eval_config = CommonEvalConfig()
    gen_parser_from_dataclass(group, eval_config)
|
|
|
|
|
def add_eval_lm_args(parser):
    """
    Add language-model evaluation options to *parser*.

    The "LM Evaluation" group receives the common evaluation options and the
    ``EvalLMConfig`` options.

    Returns:
        the created argument group. It was previously None; returning it
        matches the other ``add_*_args`` helpers.
    """
    group = parser.add_argument_group("LM Evaluation")
    add_common_eval_args(group)
    gen_parser_from_dataclass(group, EvalLMConfig())
    return group
|
|
|
|
|
def add_generation_args(parser):
    """
    Add generation options to *parser* and return their group.

    The "Generation" group receives the common evaluation options followed by
    the ``GenerationConfig`` options.
    """
    generation_group = parser.add_argument_group("Generation")
    add_common_eval_args(generation_group)
    generation_config = GenerationConfig()
    gen_parser_from_dataclass(generation_group, generation_config)
    return generation_group
|
|
|
|
|
def add_speech_generation_args(parser):
    """
    Add speech-generation options to *parser* and return their group.

    The "Speech Generation" group receives the common evaluation options plus
    an EOS-probability threshold (float, default 0.5).
    """
    group = parser.add_argument_group("Speech Generation")
    add_common_eval_args(group)

    # Accept the dash-separated spelling used by the other flags in this
    # module, in addition to the original underscore spelling. Both map to
    # args.eos_prob_threshold.
    group.add_argument('--eos_prob_threshold', '--eos-prob-threshold',
                       default=0.5, type=float,
                       help='terminate when eos probability exceeds this')

    return group
|
|
|
|
|
def add_interactive_args(parser):
    """
    Add the ``InteractiveConfig`` options to *parser*.

    Returns:
        the created "Interactive" argument group. It was previously None;
        returning it matches the other ``add_*_args`` helpers.
    """
    group = parser.add_argument_group("Interactive")
    gen_parser_from_dataclass(group, InteractiveConfig())
    return group
|
|
|
|
|
def add_model_args(parser):
    """
    Add the ``--arch``/``-a`` option to a "Model configuration" group.

    Valid choices are the keys of ``ARCH_MODEL_REGISTRY``. No default is
    given, so the value is None when the option is omitted. Model-specific
    options are added later, once the architecture is known.

    Returns:
        the created argument group.
    """
    model_group = parser.add_argument_group("Model configuration")

    # Imported locally, as elsewhere in this module.
    from fairseq.models import ARCH_MODEL_REGISTRY

    model_group.add_argument(
        '--arch',
        '-a',
        metavar='ARCH',
        choices=ARCH_MODEL_REGISTRY.keys(),
        help='model architecture',
    )
    return model_group
|
|
|
|
|
def get_args(
    data: Union[str, Path],
    task: str = "translation",
    arch: str = "transformer",
    **overrides
):
    """
    Build training args programmatically, then apply attribute overrides.

    Args:
        data: positional data argument passed to the training parser.
        task: value for ``--task`` (also used as the parser's default task).
        arch: value for ``--arch``.
        **overrides: attributes set on the resulting namespace after parsing.

    Returns:
        argparse.Namespace: the parsed and overridden arguments.
    """
    training_parser = get_training_parser(task)
    argv = [str(data), "--task", task, "--arch", arch]
    namespace = parse_args_and_arch(training_parser, argv)
    for attr_name, attr_value in overrides.items():
        setattr(namespace, attr_name, attr_value)
    return namespace
|
|
|
|
|
def add_ema_args(parser):
    """
    Add the ``EMAConfig`` options to *parser*.

    Returns:
        the created "EMA configuration" argument group. It was previously
        None; returning it matches the other ``add_*_args`` helpers.
    """
    group = parser.add_argument_group("EMA configuration")
    gen_parser_from_dataclass(group, EMAConfig())
    return group
|
|