whisper-static-shapes-onnx / run_whisper_ort.py
Félix Marty
update
796499f
import onnxruntime as ort
import numpy as np
from transformers import AutoProcessor
from datasets import load_dataset
import soundfile as sf
# Locations of the exported encoder and decoder ONNX graphs.
encoder_path = "/home/fxmarty/hf_internship/optimum/whisper_onnx/encoder_model.onnx"
decoder_path = "/home/fxmarty/hf_internship/optimum/whisper_onnx/decoder_model.onnx"


def _load_cpu_session(model_path):
    """Create an ONNX Runtime session for ``model_path`` using only the CPU provider."""
    return ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])


# One inference session per graph; the encoder is created first, then the decoder.
encoder = _load_cpu_session(encoder_path)
decoder = _load_cpu_session(decoder_path)
# Model dimensions used to size the key/value cache tensors below.
# NOTE(review): presumably these mirror the exported model's config -- confirm
# against the config shipped with the ONNX files.
batch_size = 1
encoder_num_attention_heads = 6
decoder_num_attention_heads = 6
hidden_size = 384
encoder_sequence_length = 1500
decoder_max_length = 448
num_hidden_layers = 4

# Per-head feature size: the hidden size split evenly across attention heads.
encoder_head_dim = hidden_size // encoder_num_attention_heads
decoder_head_dim = hidden_size // decoder_num_attention_heads

# Cache layouts are (batch, heads, sequence positions, per-head features).
encoder_shape = (batch_size, encoder_num_attention_heads, encoder_sequence_length, encoder_head_dim)
decoder_shape = (batch_size, decoder_num_attention_heads, decoder_max_length, decoder_head_dim)
# Processor saved alongside the ONNX files; its feature extractor is used here
# and its tokenizer is used later to decode the generated ids.
processor = AutoProcessor.from_pretrained("/home/fxmarty/hf_internship/optimum/whisper_onnx")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

idx = 0
# Read the sample once and reuse it: every ``ds[idx]["audio"]`` lookup goes back
# through the dataset (and presumably re-decodes the audio), so indexing it
# repeatedly is wasted work.
audio = ds[idx]["audio"]
waveform = audio["array"]
# Use the sample's own sampling rate rather than a hard-coded 16000, so the
# written file and the feature extraction stay consistent with the data.
sampling_rate = audio["sampling_rate"]

# Dump the waveform to disk so the exact input can be inspected or replayed.
sf.write('dummy.wav', waveform, sampling_rate)
print("audio length", waveform.shape)

# Passing ``sampling_rate`` lets the feature extractor check it against the
# rate it was configured for instead of silently assuming they match.
inputs = processor.feature_extractor(waveform, sampling_rate=sampling_rate, return_tensors="np")
print("inp shape", inputs["input_features"].shape)

# Single encoder pass; its first output is fed to the decoder as "encoder_outputs".
encoder_inputs = {"input_features": inputs["input_features"]}
encoder_outputs = encoder.run(None, encoder_inputs)
# Token ids produced so far by the decoding loop.
result = []

# Fixed-width attention mask over the decoder positions; only position 0 is
# enabled before the first decoder call.
decoder_attention_mask = np.zeros((batch_size, decoder_max_length), dtype=np.int64)
decoder_attention_mask[0, 0] = 1

# Position of the token being fed to the decoder, starting at 0.
position_ids = np.zeros((1, 1), dtype=np.int64)

token = 50257  # see whisper config decoder_start_token_id


def _random_cache(shape):
    """Return a float32 array of ``shape`` filled with uniform random values."""
    return np.random.rand(*shape).astype(np.float32)


decoder_inputs = {"encoder_outputs": encoder_outputs[0]}

# The decoder takes key/value cache inputs for every layer on every call, so
# they are seeded with random placeholders here.
# NOTE(review): presumably the placeholder contents never influence the output
# (masked out or overwritten after the first step) -- confirm against the export.
for i in range(num_hidden_layers):
    decoder_inputs.update(
        {
            f"past_key_values.{i}.decoder.key": _random_cache(decoder_shape),
            f"past_key_values.{i}.decoder.value": _random_cache(decoder_shape),
            f"past_key_values.{i}.encoder.key": _random_cache(encoder_shape),
            f"past_key_values.{i}.encoder.value": _random_cache(encoder_shape),
        }
    )
# Greedy autoregressive decoding: feed one token per decoder call, take the
# arg-max of the first output as the next token, and stop on EOS or once
# ``decoder_max_length`` tokens have been generated.
while True:
    decoder_inputs["position_ids"] = position_ids
    decoder_inputs["decoder_attention_mask"] = decoder_attention_mask
    decoder_inputs["decoder_input_ids"] = np.array([[token]]).astype(np.int64)

    decoder_outputs = decoder.run(None, decoder_inputs)

    # Output 0 is reduced over its last axis to pick the next token id
    # (presumably the logits over the vocabulary).
    token = np.argmax(decoder_outputs[0], axis=-1).item()

    pos_id = int(position_ids[0][0])

    # Enable the next position in the mask for the following step. On the very
    # last step ``pos_id + 1`` equals ``decoder_max_length``, which is one past
    # the end of the mask; the unguarded write raised IndexError there instead
    # of letting the length check below end the loop.
    next_pos = pos_id + 1
    if next_pos < decoder_max_length:
        decoder_attention_mask[:, next_pos] = 1

    for i in range(num_hidden_layers):
        # This unintuitive indexing corresponds to the ONNX output order:
        # after output 0, each layer contributes four consecutive outputs
        # (decoder key, decoder value, encoder key, encoder value).
        decoder_inputs[f"past_key_values.{i}.decoder.key"] = decoder_outputs[1 + 4 * i]
        decoder_inputs[f"past_key_values.{i}.decoder.value"] = decoder_outputs[1 + 4 * i + 1]

        # Prefill the KV cache for cross-attention: the encoder key/value
        # entries are taken from the first step only and reused afterwards.
        if pos_id == 0:
            decoder_inputs[f"past_key_values.{i}.encoder.key"] = decoder_outputs[1 + 4 * i + 2]
            decoder_inputs[f"past_key_values.{i}.encoder.value"] = decoder_outputs[1 + 4 * i + 3]

    # In-place increment: the same array object is fed back on the next step.
    position_ids += 1
    result.append(token)

    if token == processor.tokenizer.eos_token_id or len(result) == decoder_max_length:
        break

print("result", result)
print(processor.tokenizer.decode(result))