"""Masked-span audio generation evaluation script.

For each test clip, fixed-length windows of the ground-truth mel spectrogram and
the matching video features are partially zeroed, the model regenerates the
window with a CFM sampler (optionally with classifier-free guidance), and the
ground-truth, masked and generated mels are vocoded and written to ``--outdir``.
"""
import argparse
import glob  # noqa: F401  (unused; kept from the original import list)
import math
import os
import pathlib
import random
import sys
from pathlib import Path

directory = pathlib.Path(os.getcwd())
print(directory)
# Make repo-local packages (ldm, vocoder) importable when run from the repo root.
sys.path.append(str(directory))

import torch
import numpy as np
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.cfm1_audio_sampler import CFMSampler
import librosa  # noqa: F401  (unused; kept from the original import list)
from vocoder.bigvgan.models import VocoderBigVGAN
import soundfile
from tqdm import tqdm


def load_model_from_config(config, ckpt=None, verbose=True):
    """Instantiate ``config.model`` and optionally load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` section is understood by
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file; falsy values (None, "") skip loading.
        verbose: Print missing / unexpected state-dict keys.

    Returns:
        The model in eval mode, on CUDA when available, otherwise on CPU.
    """
    model = instantiate_from_config(config.model)
    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        pl_sd = torch.load(ckpt, map_location="cpu")
        # Lightning checkpoints wrap the weights in "state_dict"; also accept a bare state dict.
        sd = pl_sd.get("state_dict", pl_sd)
        epoch = pl_sd.get("epoch", "?")
        global_step = pl_sd.get("global_step")
        step = f"{global_step // 1e3}" if global_step is not None else "?"
        print(f'---------------------------epoch : {epoch}, global step: {step}k---------------------------')
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if len(missing) > 0 and verbose:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0 and verbose:
            print("unexpected keys:")
            print(unexpected)
    else:
        print("Note that no ckpt is loaded !!!")
    # Previously model.cuda() was called unconditionally, which fails on CPU-only hosts.
    model.to(torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    model.eval()
    return model


def parse_args():
    """Parse the command-line options of this evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="sample rate of wav",
    )
    parser.add_argument(
        "--length",
        type=int,
        default=None,
        help="length of wav",
    )
    parser.add_argument(
        "--test-dataset",
        default="vggsound",
        help="test which dataset: vggsound/landscape/Aist",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2audio-samples",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=25,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for the given prompt",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=1.0,  # if it's 1, only condition is taken into consideration
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        type=str,
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        default="",
    )
    return parser.parse_args()


def _apply_mask(gt_mel, gt_video_feat, start_spec, masked_spec, start_frame, masked_frame, device):
    """Zero a mel span and a video-frame span on copies; return batched tensors on ``device``.

    Works on copies so the caller's arrays (often views into a larger array) are
    left untouched, and assigns a scalar so any mel-bin count / feature width and
    spans clipped at the array edge are handled.
    """
    spec = gt_mel.copy()
    spec[:, start_spec:start_spec + masked_spec] = 0
    video = gt_video_feat.copy()
    video[start_frame:start_frame + masked_frame, :] = 0
    spec = torch.from_numpy(spec).unsqueeze(0).to(device)
    video = torch.from_numpy(video).unsqueeze(0).to(device)
    return spec, video


def get_tail_mask(spec_truncate, gt_mel, gt_video_feat, fps, sr, truncate, hop_len, device):
    """Deterministically mask half of the window, starting at 25% of ``truncate``.

    With ``spec_truncate == int(truncate / hop_len)`` (as ``main`` passes) the
    masked span is roughly the 25%-75% region of the window -- the middle,
    despite the function's name.

    Args:
        spec_truncate: Window length in mel frames.
        gt_mel: 2-D mel window, indexed ``[bins, frames]``; not modified.
        gt_video_feat: 2-D video features, indexed ``[frames, dim]``; not modified.
        fps: Video feature frame rate.
        sr: Audio sample rate.
        truncate: Window length in audio samples.
        hop_len: Audio samples per mel frame.
        device: Torch device for the returned tensors.

    Returns:
        ``(masked_mel, masked_video)`` tensors, each with a leading batch dim of 1.
    """
    masked_spec = int(spec_truncate * 0.5)  # mask 50% of the mel frames
    masked_truncate = int(masked_spec * hop_len)
    masked_frame = int(fps * masked_truncate / sr)
    start_masked_idx = truncate * 0.25
    start_masked_frame = int(fps * start_masked_idx / sr)
    start_masked_spec = int(start_masked_idx / hop_len)
    return _apply_mask(gt_mel, gt_video_feat, start_masked_spec, masked_spec,
                       start_masked_frame, masked_frame, device)


def get_random_mask(spec_truncate, gt_mel, gt_video_feat, fps, sr, truncate, hop_len, device):
    """Mask a random span that is a multiple of 16 mel frames and at most 50% of the window.

    Arguments and return value are as for ``get_tail_mask``.

    Raises:
        ValueError: if the window is too short (< 32 mel frames) to hold a 16-frame mask.
    """
    max_units = int(spec_truncate * 0.5 // 16)
    if max_units < 1:
        raise ValueError(
            f"spec_truncate={spec_truncate} is too short for a 16-frame mask (need >= 32 frames)"
        )
    masked_spec = random.randint(1, max_units) * 16  # multiple of 16 frames, at most 50%
    masked_truncate = int(masked_spec * hop_len)
    masked_frame = int(fps * masked_truncate / sr)
    start_masked_idx = random.randint(0, truncate - masked_truncate - 1)
    start_masked_frame = int(fps * start_masked_idx / sr)
    start_masked_spec = int(start_masked_idx / hop_len)
    return _apply_mask(gt_mel, gt_video_feat, start_masked_spec, masked_spec,
                       start_masked_frame, masked_frame, device)


def _read_split(path):
    """Return the stripped, non-blank lines of a split file."""
    with open(path, "r") as f:
        return [line.strip() for line in f if line.strip()]


def _build_file_lists(test_dataset, root):
    """Return ``(spec_paths, video_feat_paths)`` for the requested test set.

    Raises:
        NotImplementedError: for an unknown ``test_dataset``.
    """
    base = f'{root}/share_1316500/nlphuang/data/video_to_audio'
    if test_dataset == 'vggsound':
        data = f'{base}/vggsound/'
        names = _read_split(os.path.join(f'{base}/vggsound/split_txt', 'vggsound_test.txt'))
        spec_dir = os.path.join(data, "mel_maa2", "npy")
        feat_dir = os.path.join(data, "cavp")
        spec_list = [os.path.join(spec_dir, x) + "_mel.npy" for x in names]
        video_list = [os.path.join(feat_dir, x) + ".npz" for x in names]
        return spec_list, video_list

    # dataset -> (sub-directory, mel wav folder, visual-feature folder)
    layouts = {
        'landscape': ('landscape', 'landscape_wav', 'landscape_visual_feat'),
        'Aist': ('aist', 'AIST++_crop_wav', 'AIST++_crop_visual_feat'),
    }
    if test_dataset not in layouts:
        raise NotImplementedError(f"unsupported test dataset: {test_dataset}")
    subdir, wav_dir, visual_dir = layouts[test_dataset]
    data = f'{base}/{subdir}/'
    names = _read_split(os.path.join(f'{base}/{subdir}/split', 'test.txt'))
    spec_dir = os.path.join(data, "melnone16000", wav_dir)
    feat_dir = os.path.join(data, visual_dir)
    spec_list = [os.path.join(spec_dir, 'test', x) + ".npy" for x in names]
    video_list = [os.path.join(feat_dir, 'test', x.replace('_mel', '')) + ".npy" for x in names]
    return spec_list, video_list


def _repeat_to_length(arr, length, axis):
    """Tile ``arr`` along ``axis`` until it has at least ``length`` entries, then crop to ``int(length)``."""
    current = arr.shape[axis]
    if current < length:
        reps = [1] * arr.ndim
        reps[axis] = math.ceil(length / current)
        arr = np.tile(arr, reps)
    index = [slice(None)] * arr.ndim
    index[axis] = slice(None, int(length))
    return arr[tuple(index)]


def _load_mel(spec_path):
    """Load a mel as float32; fall back to an all-zero (80, 625) mel if loading fails."""
    try:
        return np.load(spec_path).astype(np.float32)
    except Exception:  # best-effort: a corrupted mel should not abort the whole evaluation
        print(f"corrupted mel: {spec_path}", flush=True)
        return np.zeros((80, 625), dtype=np.float32)  # [C, T]


def _load_video_feat(path):
    """Load video features as float32 from an ``.npz`` (key ``'feat'``) or a plain ``.npy`` file."""
    loaded = np.load(path)
    if isinstance(loaded, np.lib.npyio.NpzFile):
        with loaded:  # closes the underlying zip file handle
            return loaded['feat'].astype(np.float32)
    return loaded.astype(np.float32)


def main():
    """Run masked-span generation over the chosen test set and write wavs to ``--outdir``."""
    opt = parse_args()
    config = OmegaConf.load(opt.base)
    model = load_model_from_config(config, opt.resume)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = CFMSampler(model, opt.ddim_steps)
    os.makedirs(opt.outdir, exist_ok=True)
    vocoder = VocoderBigVGAN(
        config['lightning']['callbacks']['image_logger']['params']['vocoder_cfg']['params']['ckpt_vocoder'],
        device,
    )

    if os.path.exists('/apdcephfs/share_1316500/nlphuang/data/video_to_audio/vggsound/split_txt'):
        root = '/apdcephfs'
    else:
        root = '/apdcephfs_intern'
    spec_list1, video_list1 = _build_file_lists(opt.test_dataset, root)

    dataset_cfg = config['data']['params']['train']['params']['dataset_cfg']
    sr, duration = opt.sample_rate, dataset_cfg['duration']
    truncate, fps = dataset_cfg['truncate'], dataset_cfg['fps']
    hop_len = dataset_cfg['hop_len']
    truncate_frame = int(fps * truncate / sr)  # window length in video frames
    spec_truncate = int(truncate / hop_len)  # window length in mel frames

    unconditional = None
    if opt.scale != 1:
        with np.load(f'{root}/share_1316500/nlphuang/data/video_to_audio/vggsound/cavp/empty_vid.npz') as empty:
            unconditional = empty['feat'].astype(np.float32)
        unconditional = _repeat_to_length(unconditional, fps * duration, axis=0)
        unconditional = torch.from_numpy(unconditional).unsqueeze(0).to(device)
        unconditional = unconditional[:, :truncate_frame]

    # deal with long sequence
    shape = None
    if opt.length is not None:
        shape = (1, model.mel_dim, opt.length)
        from ldm.modules.diffusionmodules.flag_large_dit_moe import VideoFlagLargeDiT
        unet_params = config['model']['params']['unet_config']['params']
        # Clamp to 1: a length shorter than the training mel_length would give a factor of 0,
        # which presumably degenerates the RoPE frequency table -- TODO confirm in precompute_freqs_cis.
        ntk_factor = max(1, opt.length // config['model']['params']['mel_length'])
        print(f"override freqs_cis, ntk_factor {ntk_factor}", flush=True)
        model.model.diffusion_model.freqs_cis = VideoFlagLargeDiT.precompute_freqs_cis(
            unet_params['hidden_size'] // unet_params['num_heads'],
            unet_params['max_len'],
            ntk_factor=ntk_factor,
        )

    for spec_path, video_feat_path in zip(spec_list1, video_list1):
        name = Path(video_feat_path).stem
        # Check the LAST file written for a clip so an interrupted clip is redone on rerun.
        if os.path.exists(os.path.join(opt.outdir, f'{name}_0.wav')):
            print(f'skip {name}')
            continue

        spec_raw = _repeat_to_length(_load_mel(spec_path), sr * duration / hop_len, axis=1)
        video_feat = _repeat_to_length(_load_video_feat(video_feat_path), fps * duration, axis=0)

        window_num = video_feat.shape[0] // truncate_frame
        if window_num == 0:
            # Otherwise results from the previous clip would be vocoded under this name.
            print(f'skip {name}: shorter than one window', flush=True)
            continue

        gt_mel_list, mel_list, masked_mel_list = [], [], []
        with torch.no_grad():
            for w in tqdm(range(window_num), desc="Window:"):
                start, end = w * truncate_frame, (w + 1) * truncate_frame
                spec_start = int(start / fps * sr / hop_len)
                gt_video_feat = video_feat[start:end]
                gt_mel = spec_raw[:, spec_start:spec_start + spec_truncate]

                # Use get_random_mask(...) with the same arguments for random-span masking.
                spec, masked_video = get_tail_mask(
                    spec_truncate, gt_mel, gt_video_feat, fps, sr, truncate, hop_len, device
                )

                encoder_posterior = model.encode_first_stage(spec)
                z_spec = model.get_first_stage_encoding(encoder_posterior).detach()
                c = model.get_learned_conditioning({'mix_video_feat': masked_video, 'mix_spec': z_spec})
                if opt.scale == 1:  # w/o cfg
                    sample, _ = sampler.sample(c, 1, timesteps=opt.ddim_steps, shape=shape)
                else:  # cfg
                    uc = model.get_learned_conditioning({'mix_video_feat': unconditional, 'mix_spec': z_spec})
                    sample, _ = sampler.sample_cfg(c, opt.scale, uc, 1, timesteps=opt.ddim_steps, shape=shape)
                x_samples_ddim = model.decode_first_stage(sample)

                mel_list.append(x_samples_ddim.detach().cpu().numpy())
                masked_mel_list.append(spec.cpu().numpy())
                gt_mel_list.append(gt_mel[np.newaxis])

        # Stitch windows along the last (time) axis. The old axis=1 stacked mel bins
        # instead whenever there was more than one window; identical for one window.
        gt_mel_all = np.concatenate(gt_mel_list, axis=-1)
        syn_mel = np.concatenate(mel_list, axis=-1)
        masked_mel = np.concatenate(masked_mel_list, axis=-1)

        for idx, (gt_spec, syn_spec, masked_spec) in enumerate(zip(gt_mel_all, syn_mel, masked_mel)):
            for suffix, mel in (('_gt', gt_spec), ('_mask', masked_spec), ('', syn_spec)):
                wav = vocoder.vocode(mel)
                soundfile.write(os.path.join(opt.outdir, f'{name}_{idx}{suffix}.wav'), wav, opt.sample_rate)

    print(f"Your samples are ready and waiting for you here: \n{opt.outdir} \nEnjoy.")


if __name__ == "__main__":
    main()