# Make-An-Audio-3 / scripts / video2audio_flow_inpaint.py
# Commit a84a65c ("Add code") by 3v324v23 — file size 13.8 kB (raw / history / blame).
import argparse, os, sys, glob
import pathlib
# Make the current working directory importable so the project-local packages
# imported below (`ldm`, `vocoder`) resolve when the script is launched from the
# repository root.
directory = pathlib.Path.cwd()
print(directory)
sys.path.append(os.fspath(directory))
import torch
import numpy as np
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.cfm1_audio_sampler import CFMSampler
import random, math, librosa
from vocoder.bigvgan.models import VocoderBigVGAN
import soundfile
from pathlib import Path
from tqdm import tqdm
def load_model_from_config(config, ckpt = None, verbose=True):
    """Instantiate ``config.model`` and optionally load weights from a checkpoint.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint containing a ``state_dict`` entry (``epoch`` and
            ``global_step`` are printed when present). A falsy value (``None``, ``""``)
            skips loading and keeps the freshly initialised weights.
        verbose: If True, print the missing/unexpected keys reported by the
            non-strict ``load_state_dict``.

    Returns:
        The model in eval mode, on CUDA when a GPU is available, otherwise on CPU.
    """
    model = instantiate_from_config(config.model)
    if ckpt:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
        # epoch / global_step are informational only: don't fail if they are absent.
        epoch = pl_sd.get("epoch")
        global_step = pl_sd.get("global_step")
        step_str = f"{global_step / 1e3:.1f}k" if global_step is not None else "unknown"
        print(f'---------------------------epoch : {epoch}, global step: {step_str}---------------------------')
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
    else:
        print("Note that no ckpt is loaded !!!")
    # Only move to CUDA when a GPU exists so the function also works on CPU-only hosts.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model
def parse_args():
    """Parse this script's command-line options.

    Returns:
        argparse.Namespace. Dashed option names become underscores
        (``--test-dataset`` -> ``test_dataset``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=16000,
        help="sample rate of wav",
    )
    parser.add_argument(
        "--length",
        type=int,
        default=None,
        help="number of mel frames to generate per window (default: the model's own length)",
    )
    parser.add_argument(
        "--test-dataset",
        default="vggsound",
        help="test which dataset: vggsound/landscape/Aist",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2audio-samples",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=25,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for the given prompt (currently unused by this script)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=1.0,  # 1.0 disables classifier-free guidance: only the condition is used
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    # NOTE: passing `-r` without a value yields `True` (the const), which is not a path.
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="path of the model checkpoint to load (empty: no checkpoint is loaded)",
    )
    parser.add_argument(
        "-b",
        "--base",
        type=str,
        help="path to the model/data config file, loaded with OmegaConf",
        default="",
    )
    return parser.parse_args()
def _apply_mask(gt_mel, gt_video_feat, start_spec, n_spec, start_frame, n_frame, device):
    """Zero a time span of a mel window and a span of its video features.

    Works on copies, so the caller's arrays are not modified. Returns both as
    tensors with a leading batch dimension of 1, moved to ``device``.
    """
    spec = np.asarray(gt_mel).copy()
    video_feat = np.asarray(gt_video_feat).copy()
    # Scalar assignment works for any number of mel bins / feature dims and
    # tolerates a span that is clipped at the end of the array.
    spec[:, start_spec:start_spec + n_spec] = 0
    video_feat[start_frame:start_frame + n_frame, :] = 0
    spec = torch.from_numpy(spec).unsqueeze(0).to(device)
    video_feat = torch.from_numpy(video_feat).unsqueeze(0).to(device)
    return spec, video_feat


def get_tail_mask(spec_truncate, gt_mel, gt_video_feat, fps, sr, truncate, hop_len, device,
                  mask_ratio=0.5, start_ratio=0.25):
    """Mask a fixed span of one window's mel spectrogram and video features.

    Despite the name, with the default ratios this zeroes the middle half of the
    window: it starts at 25% of the window and covers 50% of its mel frames.

    Args:
        spec_truncate: Mel frames per window.
        gt_mel: Mel window indexed as ``[bins, time]``; not modified.
        gt_video_feat: Video features indexed as ``[frame, dim]``; not modified.
        fps: Video feature frames per second.
        sr: Audio sample rate.
        truncate: Window length in audio samples (``spec_truncate`` is derived from
            it as ``truncate / hop_len`` in ``main``).
        hop_len: Audio samples per mel frame.
        device: Target device for the returned tensors.
        mask_ratio: Fraction of ``spec_truncate`` mel frames to zero.
        start_ratio: Position of the mask start, as a fraction of ``truncate``.

    Returns:
        ``(masked_mel, masked_video_feat)`` tensors, each with a leading batch dim of 1.
    """
    masked_spec = int(spec_truncate * mask_ratio)
    masked_truncate = int(masked_spec * hop_len)  # mask length in audio samples
    masked_frame = int(fps * masked_truncate / sr)
    start_masked_idx = truncate * start_ratio  # mask start in audio samples
    start_masked_frame = int(fps * start_masked_idx / sr)
    start_masked_spec = int(start_masked_idx / hop_len)
    return _apply_mask(gt_mel, gt_video_feat, start_masked_spec, masked_spec,
                       start_masked_frame, masked_frame, device)


def get_random_mask(spec_truncate, gt_mel, gt_video_feat, fps, sr, truncate, hop_len, device):
    """Mask a random span of one window's mel spectrogram and video features.

    The span length is a random multiple of 16 mel frames, at most 50% of
    ``spec_truncate``. Its start is drawn uniformly so that the span fits inside the
    window. Arguments and return value are as for ``get_tail_mask``; inputs are not
    modified.

    Raises:
        ValueError: if ``spec_truncate`` is too short to hold a 16-frame mask.
    """
    max_units = int(spec_truncate * 0.5 // 16)
    if max_units < 1:
        raise ValueError(f"spec_truncate={spec_truncate} is too short for a 16-frame-aligned mask")
    # A multiple of 16 mel frames, masking at most 50% of the window.
    masked_spec = random.randint(1, max_units) * 16
    masked_truncate = int(masked_spec * hop_len)  # mask length in audio samples
    masked_frame = int(fps * masked_truncate / sr)
    start_masked_idx = random.randint(0, truncate - masked_truncate - 1)
    start_masked_frame = int(fps * start_masked_idx / sr)
    start_masked_spec = int(start_masked_idx / hop_len)
    return _apply_mask(gt_mel, gt_video_feat, start_masked_spec, masked_spec,
                       start_masked_frame, masked_frame, device)
def _read_split(split_file):
    """Return the non-empty, whitespace-stripped lines of a split file."""
    with open(split_file, "r") as f:
        return [line.strip() for line in f if line.strip()]


def _dataset_file_lists(test_dataset, root):
    """Return ``(spec_paths, video_feat_paths)`` for the test split of ``test_dataset``.

    Supported values are 'vggsound', 'landscape' and 'Aist'; anything else raises
    NotImplementedError. ``root`` is the storage mount prefix chosen in ``main``.
    """
    base = f'{root}/share_1316500/nlphuang/data/video_to_audio'
    if test_dataset == 'vggsound':
        split, data = f'{base}/vggsound/split_txt', f'{base}/vggsound/'
        spec_dir = os.path.join(data, "mel_maa2", "npy")
        feat_dir = os.path.join(data, "cavp")
        names = _read_split(os.path.join(split, 'vggsound_test.txt'))
        spec_paths = [os.path.join(spec_dir, x) + "_mel.npy" for x in names]
        feat_paths = [os.path.join(feat_dir, x) + ".npz" for x in names]
        return spec_paths, feat_paths
    # landscape and AIST++ share one on-disk layout; only directory names differ.
    layouts = {
        'landscape': ('landscape', 'landscape_wav', 'landscape_visual_feat'),
        'Aist': ('aist', 'AIST++_crop_wav', 'AIST++_crop_visual_feat'),
    }
    if test_dataset not in layouts:
        raise NotImplementedError
    subdir, wav_dir, feat_name = layouts[test_dataset]
    split, data = f'{base}/{subdir}/split/', f'{base}/{subdir}/'
    spec_dir = os.path.join(data, "melnone16000", wav_dir)
    feat_dir = os.path.join(data, feat_name)
    names = _read_split(os.path.join(split, 'test.txt'))
    spec_paths = [os.path.join(spec_dir, 'test', x) + ".npy" for x in names]
    feat_paths = [os.path.join(feat_dir, 'test', x.replace('_mel', '')) + ".npy" for x in names]
    return spec_paths, feat_paths


def _tile_and_trim(arr, length, axis):
    """Repeat ``arr`` along ``axis`` until it spans ``length`` entries, then cut to ``int(length)``."""
    if arr.shape[axis] < length:
        reps = [1] * arr.ndim
        reps[axis] = math.ceil(length / arr.shape[axis])
        arr = np.tile(arr, reps)
    index = [slice(None)] * arr.ndim
    index[axis] = slice(None, int(length))
    return arr[tuple(index)]


def _load_video_feat(path):
    """Load visual features as float32 from a ``.npz`` (``'feat'`` entry) or a plain ``.npy`` file."""
    data = np.load(path)
    if isinstance(data, np.ndarray):
        return data.astype(np.float32)
    with data:  # .npz archives stay open until closed
        return data['feat'].astype(np.float32)


def main():
    """Run video-to-audio inpainting over a test split and write wavs to ``opt.outdir``.

    Each clip's mel spectrogram and visual features are tiled/trimmed to the
    configured duration and split into non-overlapping windows. In every window,
    ``get_tail_mask`` zeroes a span of both inputs. The model then samples a mel
    conditioned on the masked inputs. The windows are stitched back together along
    time, and three wavs are vocoded per clip: masked input (``*_mask.wav``),
    sample (``*.wav``) and ground truth (``*_gt.wav``).
    """
    opt = parse_args()
    config = OmegaConf.load(opt.base)
    model = load_model_from_config(config, opt.resume)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = CFMSampler(model, opt.ddim_steps)
    os.makedirs(opt.outdir, exist_ok=True)
    vocoder = VocoderBigVGAN(config['lightning']['callbacks']['image_logger']['params']['vocoder_cfg']['params']['ckpt_vocoder'], device)

    # Use the '/apdcephfs' mount when its vggsound split exists, else the intern mount.
    if os.path.exists('/apdcephfs/share_1316500/nlphuang/data/video_to_audio/vggsound/split_txt'):
        root = '/apdcephfs'
    else:
        root = '/apdcephfs_intern'
    spec_list1, video_list1 = _dataset_file_lists(opt.test_dataset, root)

    dataset_cfg = config['data']['params']['train']['params']['dataset_cfg']
    sr = opt.sample_rate
    duration, truncate, fps = dataset_cfg['duration'], dataset_cfg['truncate'], dataset_cfg['fps']
    hop_len = dataset_cfg['hop_len']
    truncate_frame = int(fps * truncate / sr)  # video frames per window
    spec_truncate = int(truncate / hop_len)    # mel frames per window
    feat_len = fps * duration                  # video frames per clip
    spec_len = sr * duration / hop_len         # mel frames per clip

    unconditional = None
    if opt.scale != 1:
        # "Empty video" features used as the unconditional branch of guidance.
        unconditional = _load_video_feat(f'{root}/share_1316500/nlphuang/data/video_to_audio/vggsound/cavp/empty_vid.npz')
        unconditional = _tile_and_trim(unconditional, feat_len, axis=0)
        unconditional = torch.from_numpy(unconditional).unsqueeze(0).to(device)
        unconditional = unconditional[:, :truncate_frame]

    # Optional override of the generated length (long sequences): rescale the rotary
    # frequencies of the diffusion backbone by an NTK factor.
    shape = None
    if opt.length is not None:
        shape = (1, model.mel_dim, opt.length)
        from ldm.modules.diffusionmodules.flag_large_dit_moe import VideoFlagLargeDiT
        unet_params = config['model']['params']['unet_config']['params']
        ntk_factor = opt.length // config['model']['params']['mel_length']
        print(f"override freqs_cis, ntk_factor {ntk_factor}", flush=True)
        model.model.diffusion_model.freqs_cis = VideoFlagLargeDiT.precompute_freqs_cis(
            unet_params['hidden_size'] // unet_params['num_heads'],
            unet_params['max_len'],
            ntk_factor=ntk_factor,
        )

    for spec_path, video_feat_path in zip(spec_list1, video_list1):
        name = Path(video_feat_path).stem
        if os.path.exists(os.path.join(opt.outdir, name + '_0_gt.wav')):
            print(f'skip {name}')
            continue
        try:
            spec_raw = np.load(spec_path).astype(np.float32)  # [C, T]
        except Exception:  # best effort: substitute silence for an unreadable mel
            print(f"corrupted mel: {spec_path}", flush=True)
            spec_raw = np.zeros((80, 625), dtype=np.float32)  # [C, T]
        video_feat = _load_video_feat(video_feat_path)
        spec_raw = _tile_and_trim(spec_raw, spec_len, axis=1)
        video_feat = _tile_and_trim(video_feat, feat_len, axis=0)

        window_num = video_feat.shape[0] // truncate_frame
        if window_num == 0:
            # Nothing to generate; never reuse results from a previous clip.
            print(f"skip {name}: shorter than one window", flush=True)
            continue

        gt_mel_list, mel_list, masked_mel_list = [], [], []
        for win_idx in tqdm(range(window_num), desc="Window:"):
            start, end = win_idx * truncate_frame, (win_idx + 1) * truncate_frame
            spec_start = int(start / fps * sr / hop_len)
            gt_video_feat = video_feat[start:end]
            gt_mel = spec_raw[:, spec_start: spec_start + spec_truncate]
            # get_random_mask(...) is the alternative masking strategy.
            spec, masked_video_feat = get_tail_mask(spec_truncate, gt_mel, gt_video_feat, fps, sr, truncate, hop_len, device)
            with torch.no_grad():
                encoder_posterior = model.encode_first_stage(spec)
                z_spec = model.get_first_stage_encoding(encoder_posterior).detach()
                c = model.get_learned_conditioning({'mix_video_feat': masked_video_feat, 'mix_spec': z_spec})
                if opt.scale == 1:  # w/o cfg
                    sample, _ = sampler.sample(c, 1, timesteps=opt.ddim_steps, shape=shape)
                else:  # cfg
                    uc = model.get_learned_conditioning({'mix_video_feat': unconditional, 'mix_spec': z_spec})
                    sample, _ = sampler.sample_cfg(c, opt.scale, uc, 1, timesteps=opt.ddim_steps, shape=shape)
                x_samples_ddim = model.decode_first_stage(sample)
            mel_list.append(x_samples_ddim.detach().cpu())
            masked_mel_list.append(spec.detach().cpu())
            gt_mel_list.append(torch.from_numpy(gt_mel).unsqueeze(0))

        # Stitch windows back together along the time (last) axis; the tensors are
        # indexed [batch, bins, time] (gt mel comes from spec_raw[:, ...].unsqueeze(0)).
        gt_mel = np.concatenate([mel.numpy() for mel in gt_mel_list], axis=-1)
        syn_mel = np.concatenate([mel.numpy() for mel in mel_list], axis=-1)
        masked_mel = np.concatenate([mel.numpy() for mel in masked_mel_list], axis=-1)

        for idx, (gt_spec, syn_spec, masked_spec) in enumerate(zip(gt_mel, syn_mel, masked_mel)):
            wav = vocoder.vocode(masked_spec)
            soundfile.write(os.path.join(opt.outdir, name + f'_{idx}_mask.wav'), wav, opt.sample_rate)
            ddim_wav = vocoder.vocode(syn_spec)
            soundfile.write(os.path.join(opt.outdir, name + f'_{idx}.wav'), ddim_wav, opt.sample_rate)
            # Written last: '_0_gt.wav' is the resume marker checked above, so it
            # should only exist once every output for the clip has been written.
            wav = vocoder.vocode(gt_spec)
            soundfile.write(os.path.join(opt.outdir, name + f'_{idx}_gt.wav'), wav, opt.sample_rate)
    print(f"Your samples are ready and waiting for you here: \n{opt.outdir} \nEnjoy.")
if __name__ == "__main__":
main()