import argparse

import torch

from models.vocoders.vocoder_inference import VocoderInference
from utils.util import load_config


def build_inference(args, cfg, infer_type="infer_from_dataset"):
    # Both GAN-based and diffusion-based vocoders share the same inference
    # entry point; cfg.model_type selects the class from this mapping.
    supported_inference = {
        "GANVocoder": VocoderInference,
        "DiffusionVocoder": VocoderInference,
    }

    inference_class = supported_inference[cfg.model_type]
    return inference_class(args, cfg, infer_type)


def cuda_relevant(deterministic=False):
    torch.cuda.empty_cache()
    # Allow TF32 matmul and cuDNN kernels (Ampere and above)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic behavior; note that fully deterministic CUDA execution may
    # additionally require setting CUBLAS_WORKSPACE_CONFIG in the environment.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)


def build_parser():
    r"""Build argument parser for inference.py.

    Anything else should be put in an extra config YAML file.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--infer_mode",
        type=str,
        required=False,
        help="Inference mode, e.g. infer_from_dataset.",
    )
    parser.add_argument(
        "--infer_datasets",
        nargs="+",
        default=None,
        help="Dataset names to run inference on.",
    )
    parser.add_argument(
        "--feature_folder",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--audio_folder",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--vocoder_dir",
        type=str,
        required=True,
        help="Vocoder checkpoint directory. Checkpoint searching behavior is "
        "the same as for the acoustic model.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="result",
        help="Output directory. Default: ./result",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="warning",
        help="Logging level. Default: warning",
    )
    parser.add_argument(
        "--keep_cache",
        action="store_true",
        default=False,
        help="Keep cache files. Only applicable to inference from files.",
    )
    return parser


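# Illustrative invocation (the file and directory names below are examples
# only; the flags themselves are defined in build_parser above, and
# "infer_from_dataset" matches the default infer_type of build_inference):
#
#   python inference.py \
#       --config egs/vocoder/exp_config.json \
#       --infer_mode infer_from_dataset \
#       --vocoder_dir ckpts/vocoder \
#       --output_dir result

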
def main():
    # Parse arguments
    args = build_parser().parse_args()

    # Parse config
    cfg = load_config(args.config)

    # CUDA settings
    cuda_relevant()

    # Build inference
    trainer = build_inference(args, cfg, args.infer_mode)

    # Run inference
    trainer.inference()


if __name__ == "__main__":
    main()