import os
import glob
import torch
import torchaudio
import safetensors
from .configuration_whisper import WhisperVQConfig
from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast


def load_quantize_encoder(model_path):
    """Load only the encoder (up to the quantizer) from a WhisperVQ checkpoint."""
    config = WhisperVQConfig.from_pretrained(model_path)
    config.quantize_encoder_only = True
    model = WhisperVQEncoder(config)
    state_dict = {}
    for path in glob.glob(os.path.join(model_path, "model*.safetensors")):
        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key.startswith("model.encoder."):
                    new_key = key[len("model.encoder."):]
                    # The final layer norm and the layers past the quantize
                    # position are not needed when only encoding to tokens.
                    if new_key.startswith("layer_norm"):
                        continue
                    if new_key.startswith("layers"):
                        layer_id = int(new_key.split(".")[1])
                        if layer_id >= config.quantize_position:
                            continue
                    state_dict[new_key] = f.get_tensor(key)
    model.load_state_dict(state_dict)
    model.eval()
    model.cuda()
    return model


# Cache one Resample transform per source sample rate so repeated calls
# do not rebuild the resampling kernels.
_resample_buffer: dict[int, torchaudio.transforms.Resample] = {}


def extract_speech_token(model: WhisperVQEncoder, feature_extractor: WhisperFeatureExtractor, utts):
    """Encode utterances (file paths or (waveform, sample_rate) tuples) into
    lists of discrete speech token ids, one list per utterance."""
    with torch.no_grad():
        audios, indices = [], []
        for idx, utt in enumerate(utts):
            if isinstance(utt, tuple):
                audio, sample_rate = utt
            else:
                audio, sample_rate = torchaudio.load(utt)
            audio = audio.cuda()
            if sample_rate != 16000:
                if sample_rate not in _resample_buffer:
                    _resample_buffer[sample_rate] = torchaudio.transforms.Resample(
                        orig_freq=sample_rate,
                        new_freq=16000
                    ).to('cuda')
                audio = _resample_buffer[sample_rate](audio)
            # if audio.shape[0] > 1:
            #     audio = audio[:1]
            audio = audio[0]  # keep the first channel only
            audio = audio.cpu().numpy()
            # Split into 30-second segments (Whisper's maximum input length),
            # remembering which utterance each segment came from.
            time_step = 0
            while time_step * 16000 < audio.shape[0]:
                audio_segment = audio[time_step * 16000: (time_step + 30) * 16000]
                audios.append(audio_segment)
                indices.append(idx)
                time_step += 30
        # Total downsampling from feature frames to one token:
        # conv strides x pooling kernel x feature-extractor hop length.
        pooling_kernel_size = model.config.pooling_kernel_size or 1
        stride = model.conv1.stride[0] * model.conv2.stride[0] * pooling_kernel_size * feature_extractor.hop_length
        all_speech_tokens = [[] for _ in range(len(utts))]
        batch_size = 128
        for start in range(0, len(audios), batch_size):
            features = feature_extractor(audios[start: start + batch_size], sampling_rate=16000,
                                         return_attention_mask=True, return_tensors="pt", device='cuda',
                                         padding="longest", pad_to_multiple_of=stride)
            features = features.to(device="cuda")
            outputs = model(**features)
            speech_tokens = outputs.quantized_token_ids
            # Downsample the attention mask to the token rate so padded
            # positions can be dropped from the token sequences.
            attention_mask = features.attention_mask[:, ::model.conv1.stride[0] * model.conv2.stride[0]]
            attention_mask = attention_mask[:, ::pooling_kernel_size]
            assert attention_mask.shape == speech_tokens.shape
            for i in range(len(speech_tokens)):
                idx = indices[start + i]
                speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
                all_speech_tokens[idx].extend(speech_token)
        return all_speech_tokens
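

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions not fixed by this module: the checkpoint
# directory "./glm-4-voice-tokenizer" and the input file "sample.wav" are
# hypothetical placeholders, and the feature extractor is assumed to ship in
# the same checkpoint directory as the encoder weights. Run as a module
# (e.g. `python -m <package>.utils`) so the relative imports above resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_path = "./glm-4-voice-tokenizer"  # hypothetical local checkpoint
    encoder = load_quantize_encoder(model_path)
    fe = WhisperFeatureExtractor.from_pretrained(model_path)
    # Accepts file paths or (waveform, sample_rate) tuples; returns one token
    # id list per input utterance.
    tokens = extract_speech_token(encoder, fe, ["sample.wav"])
    print(f"{len(tokens[0])} speech tokens for the first utterance")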