import argparse, os, sys, glob
import pathlib

directory = pathlib.Path(os.getcwd())
print(directory)
sys.path.append(str(directory))

import torch
import numpy as np
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import random, math, librosa
from vocoder.bigvgan.models import VocoderBigVGAN
import soundfile
from pathlib import Path
from tqdm import tqdm


def load_model_from_config(config, ckpt=None, verbose=True):
    model = instantiate_from_config(config.model)
    if ckpt:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
        print(f'---------------------------epoch : {pl_sd["epoch"]}, global step: {pl_sd["global_step"] // 1e3}k---------------------------')
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
    else:
        print("Note that no ckpt is loaded !!!")
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="sample rate of the output wav",
    )
    parser.add_argument(
        "--length",
        type=int,
        default=None,
        help="target mel length for sampling (overrides the trained length for long generation)",
    )
    parser.add_argument(
        "--test-dataset",
        default="vggsound",
        help="which dataset to test: vggsound/landscape/Aist/yt4m",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2audio-samples",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=25,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=1.0,  # if it's 1, only the condition is taken into consideration
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        type=str,
        help="paths to base configs. Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default="",
    )
    return parser.parse_args()


def main():
    opt = parse_args()

    config = OmegaConf.load(opt.base)
    # print("-------quick debug no load ckpt---------")
    # model = instantiate_from_config(config['model'])  # for quick debug
    model = load_model_from_config(config, opt.resume)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    os.makedirs(opt.outdir, exist_ok=True)
    vocoder = VocoderBigVGAN(config['lightning']['callbacks']['image_logger']['params']['vocoder_cfg']['params']['ckpt_vocoder'], device)

    if os.path.exists('/apdcephfs/share_1316500/nlphuang/data/video_to_audio/vggsound/split_txt'):
        root = '/apdcephfs'
    else:
        root = '/apdcephfs_intern'

    if opt.test_dataset == 'vggsound':
        split, data = f'{root}/share_1316500/nlphuang/data/video_to_audio/vggsound/split_txt', f'{root}/share_1316500/nlphuang/data/video_to_audio/vggsound/'
        dataset1_spec_dir = os.path.join(data, "mel_maa2", "npy")
        dataset1_feat_dir = os.path.join(data, "cavp")

        with open(os.path.join(split, 'vggsound_test.txt'), "r") as f:
            data_list1 = f.readlines()
        data_list1 = list(map(lambda x: x.strip(), data_list1))
        spec_list1 = list(map(lambda x: os.path.join(dataset1_spec_dir, x) + "_mel.npy", data_list1))  # spec
        video_list1 = list(map(lambda x: os.path.join(dataset1_feat_dir, x) + ".npz", data_list1))  # feat
    elif opt.test_dataset == 'landscape':
        split, data = f'{root}/share_1316500/nlphuang/data/video_to_audio/landscape/split/', f'{root}/share_1316500/nlphuang/data/video_to_audio/landscape/'
        dataset1_spec_dir = os.path.join(data, "melnone16000", "landscape_wav")
        dataset1_feat_dir = os.path.join(data, "landscape_visual_feat")
        # dataset1_feat_dir = os.path.join(data, "landscape_visual_feat_structured")  # V1 cavp

        with open(os.path.join(split, 'test.txt'), "r") as f:
            data_list1 = f.readlines()
        data_list1 = list(map(lambda x: x.strip(), data_list1))
        spec_list1 = list(map(lambda x: os.path.join(dataset1_spec_dir, 'test', x) + ".npy", data_list1))  # spec
        video_list1 = list(map(lambda x: os.path.join(dataset1_feat_dir, 'test', x.replace('_mel', '')) + ".npy", data_list1))  # feat
        # video_list1 = list(map(lambda x: os.path.join(dataset1_feat_dir, 'test', x.replace('_mel', '')) + "_new_fps_4.npy", data_list1))  # feat
    elif opt.test_dataset == 'Aist':
        split, data = f'{root}/share_1316500/nlphuang/data/video_to_audio/aist/split/', f'{root}/share_1316500/nlphuang/data/video_to_audio/aist/'
        dataset1_spec_dir = os.path.join(data, "melnone16000", "AIST++_crop_wav")
        dataset1_feat_dir = os.path.join(data, "AIST++_crop_visual_feat_V1")  # V1 cavp

        with open(os.path.join(split, 'test.txt'), "r") as f:
            data_list1 = f.readlines()
        data_list1 = list(map(lambda x: x.strip(), data_list1))
        spec_list1 = list(map(lambda x: os.path.join(dataset1_spec_dir, 'test', x) + ".npy", data_list1))  # spec
        # video_list1 = list(map(lambda x: os.path.join(dataset1_feat_dir, 'test', x.replace('_mel', '')) + ".npy", data_list1))
        video_list1 = list(map(lambda x: os.path.join(dataset1_feat_dir, x.split('/')[-1].replace('_mel', '')) + "_new_fps_4.npy", data_list1))  # feat
    elif opt.test_dataset == 'yt4m':
        split, data = f'{root}/share_1316500/nlphuang/data/video_to_audio/yt8m/split/', f'{root}/share_1316500/nlphuang/data/video_to_audio/yt8m/'
        dataset1_spec_dir = os.path.join(data, "melnone16000")
        dataset1_feat_dir = os.path.join(data, "cavp")

        with open(os.path.join(split, 'test.txt'), "r") as f:
            data_list1 = f.readlines()
        data_list1 = list(map(lambda x: x.strip(), data_list1))
        spec_list1 = list(map(lambda x: os.path.join(dataset1_spec_dir, x) + "_mel.npy", data_list1))  # spec
        video_list1 = list(map(lambda x: os.path.join(dataset1_feat_dir, x) + ".npy", data_list1))  # feat
    else:
        raise NotImplementedError

    sr, duration, truncate, fps = opt.sample_rate, config['data']['params']['train']['params']['dataset_cfg']['duration'], \
        config['data']['params']['train']['params']['dataset_cfg']['truncate'], config['data']['params']['train']['params']['dataset_cfg']['fps']
    hop_len = config['data']['params']['train']['params']['dataset_cfg']['hop_len']
    truncate_frame = int(fps * truncate / sr)

    if opt.scale != 1:
        # classifier-free guidance: build the unconditional (empty-video) embedding once
        unconditional = np.load(f'{root}/share_1316500/nlphuang/data/video_to_audio/vggsound/cavp/empty_vid.npz')['feat'].astype(np.float32)
        feat_len = fps * duration
        if unconditional.shape[0] < feat_len:
            unconditional = np.tile(unconditional, (math.ceil(feat_len / unconditional.shape[0]), 1))
        unconditional = unconditional[:int(feat_len)]
        unconditional = torch.from_numpy(unconditional).unsqueeze(0).to(device)
        unconditional = unconditional[:, :truncate_frame]
        uc = model.get_learned_conditioning(unconditional)

    # deal with long sequences: override the sample shape and recompute freqs_cis with an NTK factor
    shape = None
    if opt.length is not None:
        shape = (1, model.mel_dim, opt.length)
        from ldm.modules.diffusionmodules.flag_large_dit_moe import VideoFlagLargeDiT
        ntk_factor = opt.length // config['model']['params']['mel_length']
        # if hasattr(model.model.diffusion_model, 'ntk_factor') and ntk_factor != model.model.diffusion_model.ntk_factor:
        max_len = config['model']['params']['unet_config']['params']['max_len']
        print(f"override freqs_cis, ntk_factor {ntk_factor} max_len {max_len}", flush=True)
        model.model.diffusion_model.freqs_cis = VideoFlagLargeDiT.precompute_freqs_cis(
            config['model']['params']['unet_config']['params']['hidden_size'] // config['model']['params']['unet_config']['params']['num_heads'],
            max_len,
            ntk_factor=ntk_factor,
        )

    for spec_path, video_feat_path in zip(spec_list1, video_list1):
        name = Path(video_feat_path).stem
        if os.path.exists(os.path.join(opt.outdir, name + '_0_gt.wav')):
            print(f'skip {name}')
            continue

        # mel-spectrogram features
        try:
            spec_raw = np.load(spec_path).astype(np.float32)  # channel: 1
        except Exception:
            print(f"corrupted mel: {spec_path}", flush=True)
            spec_raw = np.zeros((80, 625), dtype=np.float32)  # [C, T]

        # video features: either an .npz archive with a 'feat' array or a raw .npy array
        try:
            video_feat = np.load(video_feat_path)['feat'].astype(np.float32)
        except Exception:
            video_feat = np.load(video_feat_path).astype(np.float32)

        # tile-pad or crop the mel to the configured duration
        spec_len = sr * duration / hop_len
        if spec_raw.shape[1] < spec_len:
            spec_raw = np.tile(spec_raw, math.ceil(spec_len / spec_raw.shape[1]))
        spec_raw = spec_raw[:, :int(spec_len)]

        # tile-pad or crop the video features to the configured duration
        feat_len = fps * duration
        if video_feat.shape[0] < feat_len:
            video_feat = np.tile(video_feat, (math.ceil(feat_len / video_feat.shape[0]), 1))
        video_feat = video_feat[:int(feat_len)]

        spec_raw = torch.from_numpy(spec_raw).unsqueeze(0).to(device)      # [1, C, T]
        video_feat = torch.from_numpy(video_feat).unsqueeze(0).to(device)  # [1, T_video, D]

        feat_len = video_feat.shape[1]
        window_num = feat_len // truncate_frame
        gt_mel_list, mel_list = [], []  # [sample_list1, sample_list2, sample_list3, ...]
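        # Windowed generation (comment added for readability): each window spans truncate_frame
        # visual frames and the matching slice of mel frames, so clips longer than one window
        # are synthesized chunk by chunk and concatenated afterwards.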
        for i in tqdm(range(window_num), desc="Window:"):
            start, end = i * truncate_frame, (i + 1) * truncate_frame
            spec_start, spec_end = int(start / fps * sr / hop_len), int(end / fps * sr / hop_len)

            c = model.get_learned_conditioning(video_feat[:, start:end])

            if opt.scale == 1:  # w/o cfg
                sample, _ = model.sample(c, 1, timesteps=opt.ddim_steps, shape=shape)
            else:  # cfg
                sample, _ = model.sample_cfg(c, opt.scale, uc, 1, timesteps=opt.ddim_steps, shape=shape)
            x_samples_ddim = model.decode_first_stage(sample)

            mel_list.append(x_samples_ddim)
            gt_mel_list.append(spec_raw[:, spec_start:spec_end])

        if len(mel_list) > 0:
            syn_mel = np.concatenate([mel.cpu() for mel in mel_list], 1)
        if len(gt_mel_list) > 0:
            gt_mel = np.concatenate([mel.cpu() for mel in gt_mel_list], 1)

        for idx, (spec, x_samples_ddim) in enumerate(zip(gt_mel, syn_mel)):
            # vocode the ground-truth mel and the generated mel to waveforms
            wav = vocoder.vocode(spec)
            wav_path = os.path.join(opt.outdir, name + f'_{idx}_gt.wav')
            soundfile.write(wav_path, wav, opt.sample_rate)

            ddim_wav = vocoder.vocode(x_samples_ddim)
            wav_path = os.path.join(opt.outdir, name + f'_{idx}.wav')
            soundfile.write(wav_path, ddim_wav, opt.sample_rate)

    print(f"Your samples are ready and waiting for you here: \n{opt.outdir} \nEnjoy.")


if __name__ == "__main__":
    main()
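
# Example invocation (script and config file names below are illustrative, not fixed by this repo;
# the flags themselves come from parse_args above):
#   python generate_samples.py -b configs/video2audio.yaml -r path/to/model.ckpt \
#       --test-dataset vggsound --ddim_steps 25 --scale 1.0 --outdir outputs/txt2audio-samples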