import onnxruntime as ort
import numpy as np
from transformers import AutoProcessor
from datasets import load_dataset
import soundfile as sf

encoder_path = "/home/fxmarty/hf_internship/optimum/whisper_onnx/encoder_model.onnx"
decoder_path = "/home/fxmarty/hf_internship/optimum/whisper_onnx/decoder_model.onnx"

encoder = ort.InferenceSession(encoder_path, providers=["CPUExecutionProvider"])
decoder = ort.InferenceSession(decoder_path, providers=["CPUExecutionProvider"])

# Static shapes for the exported model (these dimensions match the whisper-tiny config).
batch_size = 1
encoder_num_attention_heads = 6
decoder_num_attention_heads = 6
hidden_size = 384
encoder_sequence_length = 1500
decoder_max_length = 448
num_hidden_layers = 4

# KV cache shapes: (batch, num_heads, sequence_length, head_dim).
encoder_shape = (
    batch_size,
    encoder_num_attention_heads,
    encoder_sequence_length,
    hidden_size // encoder_num_attention_heads,
)
decoder_shape = (
    batch_size,
    decoder_num_attention_heads,
    decoder_max_length,
    hidden_size // decoder_num_attention_heads,
)

processor = AutoProcessor.from_pretrained("/home/fxmarty/hf_internship/optimum/whisper_onnx")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

idx = 0
sf.write("dummy.wav", ds[idx]["audio"]["array"], 16000)
print("audio length", ds[idx]["audio"]["array"].shape)

inputs = processor.feature_extractor(ds[idx]["audio"]["array"], return_tensors="np")
print("inp shape", inputs["input_features"].shape)

# Run the encoder once; its output is reused at every decoding step.
encoder_inputs = {"input_features": inputs["input_features"]}
encoder_outputs = encoder.run(None, encoder_inputs)

result = []

# The attention mask spans the full static decoder length; only the positions
# filled so far are set to 1.
decoder_attention_mask = np.zeros((batch_size, decoder_max_length), dtype=np.int64)
# decoder_attention_mask[0, -1] = 1
decoder_attention_mask[0, 0] = 1

position_ids = np.array([[0]], dtype=np.int64)

token = 50257  # see whisper config decoder_start_token_id

decoder_inputs = {
    "encoder_outputs": encoder_outputs[0],
}
# On the first step the KV cache content does not matter (everything past
# position 0 is masked out by decoder_attention_mask), so fill it with random
# data of the right static shape.
for i in range(num_hidden_layers):
    decoder_inputs[f"past_key_values.{i}.decoder.key"] = np.random.rand(*decoder_shape).astype(np.float32)
    decoder_inputs[f"past_key_values.{i}.decoder.value"] = np.random.rand(*decoder_shape).astype(np.float32)
    decoder_inputs[f"past_key_values.{i}.encoder.key"] = np.random.rand(*encoder_shape).astype(np.float32)
    decoder_inputs[f"past_key_values.{i}.encoder.value"] = np.random.rand(*encoder_shape).astype(np.float32)

# Greedy decoding loop, one token per iteration.
while True:
    decoder_inputs["position_ids"] = position_ids
    decoder_inputs["decoder_attention_mask"] = decoder_attention_mask
    decoder_inputs["decoder_input_ids"] = np.array([[token]], dtype=np.int64)

    decoder_outputs = decoder.run(None, decoder_inputs)

    token = np.argmax(decoder_outputs[0], axis=-1).item()

    pos_id = position_ids[0, 0]
    # Unmask the slot the next token will occupy, guarding against stepping
    # past the static maximum length.
    if pos_id + 1 < decoder_max_length:
        decoder_attention_mask[:, pos_id + 1] = 1

    for i in range(num_hidden_layers):
        # This unintuitive indexing corresponds to the ONNX output order.
        decoder_inputs[f"past_key_values.{i}.decoder.key"] = decoder_outputs[1 + 4 * i]
        decoder_inputs[f"past_key_values.{i}.decoder.value"] = decoder_outputs[1 + 4 * i + 1]

        # Prefill the KV cache for cross-attention: it depends only on the
        # encoder output, so it is computed once on the first step and then
        # reused unchanged.
        if pos_id == 0:
            decoder_inputs[f"past_key_values.{i}.encoder.key"] = decoder_outputs[1 + 4 * i + 2]
            decoder_inputs[f"past_key_values.{i}.encoder.value"] = decoder_outputs[1 + 4 * i + 3]

    position_ids += 1
    result.append(token)

    if token == processor.tokenizer.eos_token_id or len(result) == decoder_max_length:
        break

print("result", result)
print(processor.tokenizer.decode(result))
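
# ---------------------------------------------------------------------------
# Optional cross-check (a minimal sketch, not part of the ONNX path above):
# run the reference PyTorch generation from transformers and compare the
# transcripts. The checkpoint name "openai/whisper-tiny" is an assumption --
# the dimensions above match whisper-tiny, but substitute whichever checkpoint
# the ONNX export was actually produced from.
import torch
from transformers import WhisperForConditionalGeneration

pt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")  # assumed checkpoint
pt_inputs = processor.feature_extractor(ds[idx]["audio"]["array"], return_tensors="pt")
with torch.no_grad():
    reference_ids = pt_model.generate(pt_inputs["input_features"])
print("reference", processor.tokenizer.decode(reference_ids[0], skip_special_tokens=True))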