import os
import io
import glob
import math
import tarfile

import torch
import torchaudio
import safetensors
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast

from .configuration_whisper import WhisperVQConfig
from .modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration
|
|
def load_quantize_encoder(model_path):
    """Load only the pre-quantizer part of a WhisperVQ checkpoint.

    Weights for encoder layers at or above ``config.quantize_position`` and
    for the final layer norm are skipped, since they are not needed when the
    encoder is used solely to produce quantized speech tokens.
    """
    config = WhisperVQConfig.from_pretrained(model_path)
    config.quantize_encoder_only = True
    model = WhisperVQEncoder(config)
    state_dict = {}
    for path in glob.glob(os.path.join(model_path, "model*.safetensors")):
        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key.startswith("model.encoder."):
                    new_key = key[len("model.encoder."):]
                    if new_key.startswith("layer_norm"):
                        continue
                    if new_key.startswith("layers"):
                        layer_id = int(new_key.split(".")[1])
                        if layer_id >= config.quantize_position:
                            continue
                    state_dict[new_key] = f.get_tensor(key)
    model.load_state_dict(state_dict)
    model.eval()
    model.cuda()
    return model
|
|
# Cache one Resample transform per source sample rate so repeated calls reuse it.
_resample_buffer: dict[int, torchaudio.transforms.Resample] = {}
|
|
def extract_speech_token(model: WhisperVQEncoder, feature_extractor: WhisperFeatureExtractor, utts):
    """Convert utterances into discrete speech token ids.

    Each item in ``utts`` is either a ``(waveform, sample_rate)`` tuple or a
    path readable by ``torchaudio.load``. Audio is resampled to 16 kHz, split
    into 30-second segments, encoded in batches, and the tokens of all
    segments belonging to one utterance are concatenated in order.
    """
    with torch.no_grad():
        audios, indices = [], []
        for idx, utt in enumerate(utts):
            if isinstance(utt, tuple):
                audio, sample_rate = utt
            else:
                audio, sample_rate = torchaudio.load(utt)
            audio = audio.cuda()
            if sample_rate != 16000:
                if sample_rate not in _resample_buffer:
                    _resample_buffer[sample_rate] = torchaudio.transforms.Resample(
                        orig_freq=sample_rate,
                        new_freq=16000
                    ).to('cuda')
                audio = _resample_buffer[sample_rate](audio)
            # Keep the first channel and split into 30-second segments
            # (the window length Whisper operates on).
            audio = audio[0]
            audio = audio.cpu().numpy()
            time_step = 0
            while time_step * 16000 < audio.shape[0]:
                audio_segment = audio[time_step * 16000: (time_step + 30) * 16000]
                audios.append(audio_segment)
                indices.append(idx)
                time_step += 30
        pooling_kernel_size = model.config.pooling_kernel_size or 1
        # Number of raw audio samples that map to one quantized token.
        stride = model.conv1.stride[0] * model.conv2.stride[0] * pooling_kernel_size * feature_extractor.hop_length
        all_speech_tokens = [[] for _ in range(len(utts))]
        batch_size = 128
        for start in range(0, len(audios), batch_size):
            features = feature_extractor(audios[start: start + batch_size], sampling_rate=16000,
                                         return_attention_mask=True, return_tensors="pt", device='cuda',
                                         padding="longest", pad_to_multiple_of=stride)
            features = features.to(device="cuda")
            outputs = model(**features)
            speech_tokens = outputs.quantized_token_ids
            # Downsample the frame-level attention mask to the token rate so it
            # lines up with the quantized token ids.
            attention_mask = features.attention_mask[:, ::model.conv1.stride[0] * model.conv2.stride[0]]
            attention_mask = attention_mask[:, ::pooling_kernel_size]
            assert attention_mask.shape == speech_tokens.shape
            for i in range(len(speech_tokens)):
                idx = indices[start + i]
                speech_token = speech_tokens[i][attention_mask[i].bool()].tolist()
                all_speech_tokens[idx].extend(speech_token)
        return all_speech_tokens
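

# The block below is a minimal usage sketch, not part of the original module:
# it assumes a local checkpoint directory that holds both the WhisperVQ encoder
# weights and a matching feature-extractor config, plus some audio file; both
# paths are placeholders.
if __name__ == "__main__":
    model_path = "path/to/whisper-vq-checkpoint"  # hypothetical checkpoint dir
    audio_path = "path/to/utterance.wav"          # hypothetical audio file
    encoder = load_quantize_encoder(model_path)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path)
    tokens = extract_speech_token(encoder, feature_extractor, [audio_path])
    print(f"extracted {len(tokens[0])} speech tokens for {audio_path}")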
|