import os
import io
import glob
import math
import tarfile
import torch
import torchaudio
import safetensors
from .configuration_whisper import WhisperVQConfig
from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast


def load_quantize_encoder(model_path):
    """Load only the encoder weights up to the quantizer from a WhisperVQ checkpoint."""
    config = WhisperVQConfig.from_pretrained(model_path)
    config.quantize_encoder_only = True
    model = WhisperVQEncoder(config)
    state_dict = {}
    for path in glob.glob(os.path.join(model_path, "model*.safetensors")):
        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key.startswith("model.encoder."):
                    new_key = key[len("model.encoder."):]
                    # Skip the final layer norm and any layers at or after the
                    # quantization position: they are not part of the quantize-only encoder.
                    if new_key.startswith("layer_norm"):
                        continue
                    if new_key.startswith("layers"):
                        layer_id = int(new_key.split(".")[1])
                        if layer_id >= config.quantize_position:
                            continue
                    state_dict[new_key] = f.get_tensor(key)
    model.load_state_dict(state_dict)
    model.eval()
    model.cuda()
    return model


# Cache one Resample transform per source sample rate so it is only built once.
_resample_buffer: dict[int, torchaudio.transforms.Resample] = {}


def extract_speech_token(model: WhisperVQEncoder, feature_extractor: WhisperFeatureExtractor, utts):
    """Convert utterances (file paths or (waveform, sample_rate) tuples) into discrete speech token ids."""
    with torch.no_grad():
        audios, indices = [], []
        for idx, utt in enumerate(utts):
            if isinstance(utt, tuple):
                audio, sample_rate = utt
            else:
                audio, sample_rate = torchaudio.load(utt)
            audio = audio.cuda()
            # Resample to 16 kHz, reusing the cached transform for this source rate.
            if sample_rate != 16000:
                if sample_rate not in _resample_buffer:
                    _resample_buffer[sample_rate] = torchaudio.transforms.Resample(
                        orig_freq=sample_rate,
                        new_freq=16000
                    ).to('cuda')
                audio = _resample_buffer[sample_rate](audio)
            # if audio.shape[0] > 1:
            #     audio = audio[:1]
            audio = audio[0]
            audio = audio.cpu().numpy()
            # Split each utterance into 30-second segments, remembering which utterance each came from.
            time_step = 0
            while time_step * 16000 < audio.shape[0]:
                audio_segment = audio[time_step * 16000: (time_step + 30) * 16000]
                audios.append(audio_segment)
                indices.append(idx)
                time_step += 30
        pooling_kernel_size = model.config.pooling_kernel_size or 1
        stride = model.conv1.stride[0] * model.conv2.stride[0] * pooling_kernel_size * feature_extractor.hop_length
        all_speech_tokens = [[] for _ in range(len(utts))]
        batch_size = 128
        for start in range(0, len(audios), batch_size):
            features = feature_extractor(audios[start: start + batch_size], sampling_rate=16000,
                                         return_attention_mask=True, return_tensors="pt", device='cuda',
                                         padding="longest", pad_to_multiple_of=stride)
            features = features.to(device="cuda")
            outputs = model(**features)
            speech_tokens = outputs.quantized_token_ids
            # Downsample the frame-level attention mask to the token rate
            # (conv strides, then pooling), so it aligns with the quantized tokens.
            attention_mask = features.attention_mask[:, ::model.conv1.stride[0] * model.conv2.stride[0]]
            attention_mask = attention_mask[:, ::pooling_kernel_size]
            assert attention_mask.shape == speech_tokens.shape
            for i in range(len(speech_tokens)):
                idx = indices[start + i]
                speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
                all_speech_tokens[idx].extend(speech_token)
        return all_speech_tokens
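

# Minimal usage sketch (not part of the original module). It assumes the checkpoint
# directory also holds a preprocessor config readable by WhisperFeatureExtractor;
# the paths below are hypothetical placeholders.
if __name__ == "__main__":
    model_path = "path/to/whisper-vq-checkpoint"  # hypothetical checkpoint directory
    encoder = load_quantize_encoder(model_path)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
    # Utterances may be audio file paths or (waveform, sample_rate) tuples.
    tokens = extract_speech_token(encoder, feature_extractor, ["example.wav"])
    print(f"first utterance -> {len(tokens[0])} speech tokens")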