import argparse
import os

import torch
import torchaudio

import text
import utils.make_html as html
from utils import get_basic_config
from utils.plotting import get_spectrogram_figure
from vocoder import load_hifigan
from vocoder.hifigan.denoiser import Denoiser

# Default:
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test

# Examples:
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test_fp_adv
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --denoise 0.01 --out_dir samples/test_fp_adv_d
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_mse.pth --out_dir samples/test_fp_mse
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --out_dir samples/test_tc2_adv
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --denoise 0.01 --out_dir samples/test_tc2_adv_d
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_mse.pth --out_dir samples/test_tc2_mse


def test(args, text_arabic):
    use_cuda_if_available = not args.cpu
    device = torch.device(
        'cuda' if torch.cuda.is_available() and use_cuda_if_available else 'cpu')
    out_dir = args.out_dir
    sample_rate = 22_050  # output rate of the HiFi-GAN vocoder

    # Load acoustic model (text -> mel spectrogram)
    if args.model == 'fastpitch':
        from models.fastpitch import FastPitch
        model = FastPitch(args.checkpoint)
    elif args.model == 'tacotron2':
        from models.tacotron2 import Tacotron2
        model = Tacotron2(args.checkpoint)
    else:
        # Raising a bare string is a TypeError in Python 3; raise a real exception
        raise ValueError(f"Model type not supported: {args.model}")
    print(f'Loaded {args.model} from: {args.checkpoint}')
    model.eval()

    # Load vocoder model (mel spectrogram -> waveform); fall back to the
    # paths from the basic config when none are given on the command line
    if args.vocoder_sd is None or args.vocoder_config is None:
        config = get_basic_config()
        if args.vocoder_sd is None:
            args.vocoder_sd = config.vocoder_state_path
        if args.vocoder_config is None:
            args.vocoder_config = config.vocoder_config_path

    vocoder = load_hifigan(
        state_dict_path=args.vocoder_sd,
        config_file=args.vocoder_config)
    print(f'Loaded vocoder from: {args.vocoder_sd}')

    model, vocoder = model.to(device), vocoder.to(device)
    denoiser = Denoiser(vocoder)

    # Infer spectrogram and waveform
    with torch.inference_mode():
        mel_spec = model.ttmel(text_arabic, vowelizer=args.vowelizer)
        wave = vocoder(mel_spec[None])  # add a batch dimension
        if args.denoise > 0:
            wave = denoiser(wave, args.denoise)

    # Save waveform and images
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f"Created folder: {out_dir}")

    torchaudio.save(f'{out_dir}/wave.wav', wave[0].cpu(), sample_rate)
    get_spectrogram_figure(mel_spec.cpu()).savefig(
        f'{out_dir}/mel_spec.png')

    # Phonemic transcription shown alongside the audio in the HTML page
    t_phon = text.arabic_to_phonemes(text_arabic)
    t_phon = text.simplify_phonemes(t_phon.replace(' ', '').replace('+', ' '))

    with open(f'{out_dir}/index.html', 'w', encoding='utf-8') as f:
        f.write(html.make_html_start())
        f.write(html.make_h_tag("Test sample", n=1))
        f.write(html.make_sample_entry2("./wave.wav", text_arabic, t_phon))
        f.write(html.make_h_tag("Spectrogram"))
        f.write(html.make_img_tag('./mel_spec.png'))
        f.write(html.make_volume_script(0.42))
        f.write(html.make_html_end())

    print(f"Saved test sample to: {out_dir}")

    # Play the result locally if sounddevice is available; drop the batch
    # and channel dimensions before playback
    if not args.do_not_play:
        try:
            import sounddevice as sd
            sd.play(wave[0, 0].cpu(), sample_rate, blocking=True)
        except Exception:
            pass


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--text', type=str,
                        default="أَلسَّلامُ عَلَيكُم يا صَديقي")
    parser.add_argument('--model', type=str, default='fastpitch')
    parser.add_argument(
        '--checkpoint', default='pretrained/fastpitch_ar_adv.pth')
    parser.add_argument('--vocoder_sd', type=str, default=None)
    parser.add_argument('--vocoder_config', type=str, default=None)
    parser.add_argument('--denoise', type=float, default=0)
    parser.add_argument('--out_dir', default='samples/test')
    parser.add_argument('--vowelizer', default=None)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--do_not_play', action='store_true')
    args = parser.parse_args()

    text_arabic = args.text

    test(args, text_arabic)


if __name__ == '__main__':
    main()