How to get the discrete codes correctly
I am trying to extract the discrete codes the right way, but it seems the faiss index is somehow returning wrong results? Here is my code:
```python
import torch
from transformers import HubertModel
from datasets import load_dataset
import faiss
import numpy as np


def load_index(index_path):
    index: faiss.IndexPreTransform = faiss.read_index(index_path)
    # Make sure we have access to the IVF sub-index. We'll need it to get the centroids (clusters).
    index_ivf = faiss.extract_index_ivf(index)
    return index, index_ivf


def get_centroids_index(xq, index, index_ivf):
    """Get the nearest centroid for each query vector."""
    # Get the OPQ matrix (pre-transform) from the index chain
    opq_mt = faiss.downcast_VectorTransform(index.chain.at(0))
    # Apply the pre-transform to the query
    xq_t = opq_mt.apply_py(xq)
    # Get centroids C and distances DC on the pre-transformed index
    DC, C = index_ivf.quantizer.search(xq_t, 1)
    return DC, C


class Hubert2Unit(torch.nn.Module):
    def __init__(
        self,
        model_name="utter-project/mHuBERT-147",
        kmean_path="mhubert147_faiss.index",
        dtype=torch.float32,
        device="cuda:0",
    ):
        super().__init__()
        self.model = HubertModel.from_pretrained(model_name).eval()
        self.model.to(dtype=dtype, device=device)  # trained with float32
        self.index, self.index_ivf = load_index(kmean_path)

    def zero_mean_unit_var_norm(
        self, input_values, wav_lengths, padding_value: float = 0.0
    ):
        """Every array in the list is normalized to have zero mean and unit variance."""
        if wav_lengths is not None:
            normed_input_values = []
            for vector, length in zip(input_values, wav_lengths):
                normed_slice = (vector - vector[:length].mean()) / torch.sqrt(vector[:length].var() + 1e-7)
                if length < normed_slice.shape[0]:
                    normed_slice[length:] = padding_value
                normed_input_values.append(normed_slice)
        else:
            normed_input_values = [(x - x.mean()) / torch.sqrt(x.var() + 1e-7) for x in input_values]
        return torch.stack(normed_input_values, dim=0)

    def forward(self, wav, wav_lengths, do_normalize=True):
        with torch.no_grad():
            if do_normalize:
                input_values = self.zero_mean_unit_var_norm(wav, wav_lengths)
            else:
                input_values = wav.clone()
            # Calculate the attention_mask based on the 16 kHz waveform lengths
            attention_mask = torch.arange(
                input_values.size(1),
                device=input_values.device)[None, :] < wav_lengths[:, None]
            attention_mask = attention_mask.long()
            hidden_states = self.model(
                input_values,
                attention_mask=attention_mask,
                output_hidden_states=True
            ).hidden_states[9]  # 9th layer of the encoder block
            hidden_states = hidden_states.reshape(hidden_states.size(0) * hidden_states.size(1), -1)
            hidden_states_cpu = hidden_states.float().detach().cpu().numpy()
            _, C = get_centroids_index(hidden_states_cpu, self.index, self.index_ivf)
            C = C.reshape(wav.shape[0], -1)
            n_unique_codes = len(np.unique(C))
            return C, n_unique_codes


ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True)
hubert = Hubert2Unit()
wav = ds[0]["audio"]["array"]
wav = torch.tensor(wav).to("cuda:0").unsqueeze(0).float()
lengths = torch.tensor([wav.shape[1]]).to("cuda:0")
C, n = hubert(wav, lengths)
```
@mzboito Thank you in advance.
Hi! Thanks for moving this thread to a dedicated issue.
The source of your issue is very likely the mismatch between the trained faiss index and the mHuBERT-147 model you are using.
This index here (https://huggingface.co/utter-project/mHuBERT-147/blob/main/mhubert147_faiss.index) was trained on the output of the 9th layer of the 2nd iteration mHuBERT-147 (https://huggingface.co/utter-project/mHuBERT-147-base-2nd-iter), in order to generate targets for the mHuBERT-147 3rd iteration training.
If you input mHuBERT-147 (3rd iteration) features into it, it will not cluster them very well, as it was trained on the output of a different model.
Basically, there are two settings in which you might want to use faiss:
If you want to continue pretraining the mHuBERT-147 (3rd iteration), you should extract features for your speech using the 2nd iteration's 9th layer, and then generate the indices using the faiss index you are already using (https://huggingface.co/utter-project/mHuBERT-147/blob/main/mhubert147_faiss.index). This should work.
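For example, a minimal sketch of this first setting, assuming the 2nd-iteration checkpoint loads with the same HubertModel class used in the snippet above (nothing else in that snippet needs to change):

```python
from transformers import HubertModel

# Use the 2nd-iteration checkpoint so the layer-9 features match the published faiss index;
# keep hidden_states[9] and the same "mhubert147_faiss.index" as in the snippet above.
model = HubertModel.from_pretrained("utter-project/mHuBERT-147-base-2nd-iter").eval()
```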
If you want to generate the faiss discretization using the features from the 3rd iteration (mHuBERT-147) as input, then you need to train a new index on your target data. You can check our training recommendations here: https://github.com/utter-project/mHuBERT-147-scripts
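As a rough illustration of this second setting, the sketch below trains a fresh OPQ+IVF index on your own layer-9 features; the feature file, factory string, and cluster count are placeholders, so follow the recipe in the scripts repository for the actual configuration:

```python
import faiss
import numpy as np

# Hypothetical dump of layer-9 features from the 3rd-iteration model, shape (n_frames, 768)
feats = np.load("layer9_features.npy").astype(np.float32)

# OPQ pre-transform + IVF coarse quantizer; the factory string is only illustrative
index = faiss.index_factory(feats.shape[1], "OPQ8_64,IVF1000,PQ8")
index.train(feats)  # learns the OPQ rotation and the IVF centroids
faiss.write_index(index, "my_new_faiss.index")
```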
I hope it was understandable!
@mzboito thank you so much, it seems to be correct now. Just to make sure everything matches: are do_normalize=True and hidden_states[9] correct (rather than False or hidden_states[10])? I ask because it seems that len(hidden_states) = 13, not 12.
Yes, do_normalize=True for everything.
Regarding the layer: I did feature extraction in fairseq, not HF, so I'm not 100% sure, but it should be [9] if your length is 13.
That is because the forward for feature extraction takes output_layer - 1: https://github.com/utter-project/fairseq/blob/3fb951a8658b81f09011fc2e9e5fe4c2e818a304/fairseq/models/hubert/hubert.py#L470
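For what it's worth, on the HF side the off-by-one comes from hidden_states also containing the pre-transformer embedding output at index 0. A small sanity check, reusing a HubertModel instance and normalized input_values like those in the snippet above:

```python
# For the 12-layer base model, len(hidden_states) == 12 + 1:
# hidden_states[0] is the output of the feature projection (before any transformer layer),
# hidden_states[1..12] are the outputs of the 12 transformer layers.
outputs = model(input_values, output_hidden_states=True)
assert len(outputs.hidden_states) == 13
layer9 = outputs.hidden_states[9]  # output of the 9th transformer layer
```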
@dathudeptrai thanks for the snippet, I traced the model together with clustering into a jit: https://huggingface.co/balacoon/mhubert, https://balacoon.com/blog/mhubert_tracing/.
@mzboito: can you confirm that the attention_mask does not have an effect? Also, did you train in full precision? Any suggestion why fp16 might be failing on longer sequences?
Hi @clementruhm, thanks for creating this neat extractor. :) Any chance you can rename it and add "147" to the title? This is to avoid confusion with this model: https://huggingface.co/voidful/mhubert-base
- Attention mask should not have an impact during inference
- utter-project/mHuBERT-147-base-2nd-iter is an fp16 model, unlike the final version (utter-project/mHuBERT-147).
I am surprised to hear that it is failing on longer sequences. What exactly is the error you are getting?
Updated the name. My bad regarding fp16: it turned out the audio normalization was overflowing. Fixed it and uploaded an fp16 version.
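For reference, one way to avoid that kind of overflow (a sketch, not necessarily the exact fix used in the traced model) is to compute the normalization statistics in fp32 and only cast back to fp16 afterwards:

```python
import torch

def normalize_then_cast(wav_fp16: torch.Tensor) -> torch.Tensor:
    # Accumulating the mean/variance of a long fp16 waveform can overflow,
    # so compute the statistics in fp32 and cast the normalized result back.
    wav = wav_fp16.float()
    wav = (wav - wav.mean()) / torch.sqrt(wav.var() + 1e-7)
    return wav.half()
```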