In [1]:
import sys

# Make the directory two levels above the working directory importable
# (presumably the repository root that provides the `academicodec` package
# imported in the next cell -- confirm against the repo layout).
# Guarded so that re-running this cell does not keep appending duplicates.
if '../../' not in sys.path:
    sys.path.append('../../')

In [3]:
import glob
import json
import os
from pathlib import Path

import librosa
import torch
from academicodec.models.hificodec.vqvae import VQVAE
from librosa.util import normalize
from tqdm import tqdm

# Pretrained checkpoint and the JSON config that accompanies it.
ckpt_path = './checkpoint/HiFi-Codec-24k-320d'
config_path = './config_24k_320d.json'

# Parse the config once; the sampling rate it declares is the rate every
# input wav is loaded at further down.
with open(config_path, 'r') as f:
    config = json.load(f)
sample_rate = config['sampling_rate']

# Directory the wavs are read from, directory the tokens are written to,
# and the upper bound on how many globbed files are kept.
inputdir = './test_wav'
outputdir = './output'
num = 1024

if __name__ == '__main__':
    # Encode wav files found in `inputdir` with the VQVAE encoder and save a
    # {file id -> acoustic token tensor} dict to
    # `<outputdir>/fid_to_acoustic_token.pth`.
    Path(outputdir).mkdir(parents=True, exist_ok=True)
    print("Init model and load weights")
    # Make sure you downloaded the weights from
    # https://huggingface.co/Dongchao/AcademiCodec/blob/main/HiFi-Codec-24k-320d
    # and put them in ./checkpoint/
    model = VQVAE(
        config_path,
        ckpt_path,
        with_encoder=True)
    model.cuda()
    model.eval()
    print("Model ready")

    # glob.glob returns paths in arbitrary order; sort first so that the
    # `[:num]` cap (and the slice in the loop below) select the same files on
    # every run.
    wav_paths = sorted(glob.glob(f"{inputdir}/*.wav"))[:num]
    print(f"Globbed {len(wav_paths)} wav files.")
    fid_to_acoustic_token = {}
    # Encoding is inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        # NOTE(review): only the first file is encoded (`[:1]`), which looks
        # like a demo/debug limit given the verbose prints below -- drop the
        # slice to encode every globbed file.
        for wav_path in tqdm(wav_paths[:1]):
            # Load (and resample if necessary) at the rate given by the config.
            wav, sr = librosa.load(wav_path, sr=sample_rate)
            print("wav.shape:",wav.shape)
            assert sr == sample_rate
            # File id: the base name with its 4-character ".wav" suffix removed.
            fid = os.path.basename(wav_path)[:-4]
            # Normalize, then scale by 0.95.
            wav = normalize(wav) * 0.95
            # Add a leading dimension of size 1 and move the input to the GPU.
            wav = torch.FloatTensor(wav).unsqueeze(0)
            wav = wav.to(torch.device('cuda'))
            acoustic_token = model.encode(wav)
            print("acoustic_token:",acoustic_token)
            print("acoustic_token.shape:",acoustic_token.shape)
            print("acoustic_token.dtype:",acoustic_token.dtype)
            # Keep the result on the CPU: this releases GPU memory between
            # files and lets the saved file be loaded on machines without CUDA.
            fid_to_acoustic_token[fid] = acoustic_token.cpu()

    torch.save(fid_to_acoustic_token,
               os.path.join(outputdir, 'fid_to_acoustic_token.pth'))


Init model and load weights
Model ready
Globbed 12 wav files.


100%|███████████| 1/1 [00:00<00:00, 11.08it/s]

wav.shape: (97681,)
acoustic_token: tensor([[[ 11, 591, 281, 629],
         [733, 591, 401, 139],
         [500, 591, 733, 600],
         ...,
         [733, 591, 451, 346],
         [733, 591, 401, 139],
         [386, 591, 281, 461]]], device='cuda:0')
acoustic_token.shape: torch.Size([1, 305, 4])
acoustic_token.dtype: torch.int64



