# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional

from fairseq.data import Dictionary

logger = logging.getLogger(__name__)


def get_config_from_yaml(yaml_path: Path):
    try:
        import yaml
    except ImportError:
        raise ImportError("Please install PyYAML: pip install PyYAML")

    config = {}
    if yaml_path.is_file():
        try:
            with open(yaml_path) as f:
                config = yaml.load(f, Loader=yaml.FullLoader)
        except Exception as e:
            raise Exception(f"Failed to load config from {yaml_path.as_posix()}: {e}")
    else:
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")

    return config
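

# Illustrative usage (a minimal sketch; the file name and values below are
# hypothetical, but the keys match the ones read by the wrapper classes
# defined further down in this module):
#
#     # config.yaml (hypothetical contents):
#     #   vocab_filename: dict.txt
#     #   input_feat_per_channel: 80
#     #   input_channels: 1
#     #   sampling_alpha: 1.0
#     config = get_config_from_yaml(Path("config.yaml"))
#     assert config["input_feat_per_channel"] == 80
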

class S2TDataConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        self.config = get_config_from_yaml(yaml_path)
        self.root = yaml_path.parent

    def _auto_convert_to_abs_path(self, x):
        if isinstance(x, str):
            if not Path(x).exists() and (self.root / x).exists():
                return (self.root / x).as_posix()
        elif isinstance(x, dict):
            return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}

        return x

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def speaker_set_filename(self):
        """speaker set file under data root"""
        return self.config.get("speaker_set_filename", None)

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returns
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returns
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for the
        to-many multilingual setting). During inference, this requires
        `--prefix-size 1` to force BOS to be the lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target
        (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sample_rate(self):
        return self.config.get("sample_rate", 16_000)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    def standardize_audio(self) -> bool:
        return self.use_audio_input and self.config.get("standardize_audio", False)

    @property
    def use_sample_rate(self):
        """Needed by the dataset loader to see if the model requires
        raw audio with a specific sample rate as inputs."""
        return self.config.get("use_sample_rate", 16000)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_transforms(self, transform_type, split, is_train):
        """Split-specific feature transforms. Allows train set
        wildcard `_train`, evaluation set wildcard `_eval` and general
        wildcard `*` for matching."""
        cfg = deepcopy(self.config)
        _cur = cfg.get(f"{transform_type}transforms", {})
        cur = _cur.get(split)
        cur = _cur.get("_train") if cur is None and is_train else cur
        cur = _cur.get("_eval") if cur is None and not is_train else cur
        cur = _cur.get("*") if cur is None else cur
        return cur

    def get_feature_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        # TODO: deprecate transforms
        cur = self.get_transforms("", split, is_train)
        if cur is not None:
            logger.warning(
                "Auto converting transforms into feature_transforms, "
                "but transforms will be deprecated in the future. Please "
                "update this in the config."
            )
            ft_transforms = self.get_transforms("feature_", split, is_train)
            if ft_transforms:
                cur.extend(ft_transforms)
        else:
            cur = self.get_transforms("feature_", split, is_train)
        cfg["feature_transforms"] = cur
        return cfg

    def get_waveform_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
        return cfg

    def get_dataset_transforms(self, split, is_train):
        cfg = deepcopy(self.config)
        cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
        return cfg

    @property
    def global_cmvn_stats_npz(self) -> Optional[str]:
        path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
        return self._auto_convert_to_abs_path(path)

    @property
    def vocoder(self) -> Dict[str, str]:
        vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
        return self._auto_convert_to_abs_path(vocoder)

    @property
    def hub(self) -> Dict[str, str]:
        return self.config.get("hub", {})
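

# Illustrative usage of the split-specific transform lookup (a minimal sketch;
# the YAML path, split name, and transform names are illustrative only):
#
#     # config.yaml (hypothetical contents):
#     #   feature_transforms:
#     #     _train: [specaugment]
#     #     '*': [utterance_cmvn]
#     data_cfg = S2TDataConfig(Path("config.yaml"))
#     # Returns a deep copy of the config with "feature_transforms" resolved
#     # for this split; the `_train` entry matches because is_train=True.
#     cfg = data_cfg.get_feature_transforms("train_st", is_train=True)
#     assert cfg["feature_transforms"] == ["specaugment"]
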

class S2SDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Optional[Dict]:
        return None

    @property
    def bpe_tokenizer(self) -> Optional[Dict]:
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms"""
        # TODO: move this into individual transforms
        # TODO: deprecate transforms
        _cur = self.config.get("transforms", {})
        ft_transforms = self.config.get("feature_transforms", {})
        if _cur and ft_transforms:
            _cur.update(ft_transforms)
        else:
            _cur = self.config.get("feature_transforms", {})
        cur = _cur.get("_train", [])

        _channels = self.input_channels
        if "delta_deltas" in cur:
            _channels *= 3

        return _channels

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)


class MultitaskConfig(object):
    """Wrapper class for multitask config YAML"""

    def __init__(self, yaml_path: Path):
        config = get_config_from_yaml(yaml_path)
        self.config = {}
        for k, v in config.items():
            self.config[k] = SingleTaskConfig(k, v)

    def get_all_tasks(self):
        return self.config

    def get_single_task(self, name):
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.
        If there are multiple tasks with 'is_first_pass_decoder: True' in the
        config file, the last one is used for the first-pass decoder.
        If there is no 'is_first_pass_decoder: True' in the config file, the
        last task whose task_name includes 'target' and whose decoder_type is
        not CTC is used.
        """
        idx = -1
        for i, (k, v) in enumerate(self.config.items()):
            if v.is_first_pass_decoder:
                idx = i
        if idx < 0:
            for i, (k, v) in enumerate(self.config.items()):
                if k.startswith("target") and v.decoder_type == "transformer":
                    idx = i
        return idx
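

# Illustrative multitask YAML layout (a minimal sketch; the task names,
# dictionary paths, and values are hypothetical). Each top-level key names one
# auxiliary task and maps to the keys read by SingleTaskConfig, defined below:
#
#     # multitask.yaml (hypothetical contents):
#     #   target_letter:
#     #     decoder_type: transformer
#     #     dict: dict_letter.txt
#     #     loss_weight: 8.0
#     #     is_first_pass_decoder: True
#     #   source_unigram:
#     #     decoder_type: ctc
#     #     dict: dict_unigram.txt
#     #     loss_weight: 4.0
#     multitask_cfg = MultitaskConfig(Path("multitask.yaml"))
#     # "target_letter" is the first entry and is flagged as first-pass decoder.
#     assert multitask_cfg.first_pass_decoder_task_index == 0
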

class SingleTaskConfig(object):
    def __init__(self, name, config):
        self.task_name = name
        self.config = config
        dict_path = config.get("dict", "")
        self.tgt_dict = (
            Dictionary.load(dict_path) if Path(dict_path).exists() else None
        )

    @property
    def data(self):
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args", {})
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion"""
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        if self.input_from == "decoder":
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target
        (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "")

    @property
    def rdrop_alpha(self):
        return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                logger.warning(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    @property
    def get_lang_tag_mapping(self):
        return self.config.get("lang_tag_mapping", {})
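

# Worked example of the loss-weight decay schedule in get_loss_weight (a
# minimal sketch; the task name and numbers are hypothetical). With
#   loss_weight_max: 1.0, loss_weight_min: 0.1, loss_weight_decay_steps: 9000
# the per-update step size is (1.0 - 0.1) / 9000 = 1e-4, so the weight decays
# linearly from 1.0 and is clamped at loss_weight_min:
#
#     task_cfg = SingleTaskConfig(
#         "target_letter",  # hypothetical task name
#         {
#             "loss_weight_max": 1.0,
#             "loss_weight_min": 0.1,
#             "loss_weight_decay_steps": 9000,
#         },
#     )
#     assert task_cfg.loss_weight_schedule == "decay"
#     assert abs(task_cfg.get_loss_weight(3000) - 0.7) < 1e-6
#     assert task_cfg.get_loss_weight(20000) == 0.1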