File size: 3,516 Bytes
7fad28d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46b9c37
7fad28d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796499f
7fad28d
 
 
 
 
 
 
 
 
 
 
 
 
796499f
7fad28d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import onnxruntime as ort
import numpy as np
from transformers import AutoProcessor
from datasets import load_dataset
import soundfile as sf

# Locations of the exported encoder and decoder ONNX graphs.
encoder_path = "/home/fxmarty/hf_internship/optimum/whisper_onnx/encoder_model.onnx"
decoder_path = "/home/fxmarty/hf_internship/optimum/whisper_onnx/decoder_model.onnx"

# One ONNX Runtime session per graph, both pinned to the CPU provider.
# The encoder session is created first, then the decoder session.
encoder, decoder = (
    ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    for model_path in (encoder_path, decoder_path)
)

# Hard-coded model dimensions used to size the key/value caches below.
# NOTE(review): presumably these mirror the exported checkpoint's config — confirm.
batch_size = 1
encoder_num_attention_heads = 6
decoder_num_attention_heads = 6
hidden_size = 384
encoder_sequence_length = 1500
decoder_max_length = 448
num_hidden_layers = 4

# Per-head width is the hidden size split evenly across the attention heads.
encoder_head_dim = hidden_size // encoder_num_attention_heads
decoder_head_dim = hidden_size // decoder_num_attention_heads

# Cache layouts: (batch, heads, sequence length, per-head width).
encoder_shape = (batch_size, encoder_num_attention_heads, encoder_sequence_length, encoder_head_dim)
decoder_shape = (batch_size, decoder_num_attention_heads, decoder_max_length, decoder_head_dim)

# Processor (feature extractor + tokenizer) stored in the same directory as the ONNX export.
processor = AutoProcessor.from_pretrained("/home/fxmarty/hf_internship/optimum/whisper_onnx")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

idx = 0
# Fetch the sample once instead of indexing the dataset for every use.
audio = ds[idx]["audio"]
waveform = audio["array"]
# Use the rate stored with the sample rather than assuming 16 kHz.
sampling_rate = audio["sampling_rate"]

# Dump the clip to disk so the transcription can be compared against the audio.
sf.write('dummy.wav', waveform, sampling_rate)
print("audio length", waveform.shape)
# Passing the rate explicitly lets the feature extractor validate it against
# the rate it was configured for instead of silently assuming a match.
inputs = processor.feature_extractor(waveform, sampling_rate=sampling_rate, return_tensors="np")
print("inp shape", inputs["input_features"].shape)

encoder_inputs = {"input_features": inputs["input_features"]}

# Single encoder pass; its first output is fed to the decoder afterwards.
encoder_outputs = encoder.run(None, encoder_inputs)

# Token ids emitted so far by the decoding loop.
result = []

# Fixed-width attention mask over the decoder cache slots; only slot 0 is
# enabled before the first decoding step.
decoder_attention_mask = np.zeros((batch_size, decoder_max_length), dtype=np.int64)
decoder_attention_mask[0, 0] = 1
# Position of the token currently being fed, kept as a (1, 1) int64 array.
position_ids = np.zeros((1, 1), dtype=np.int64)

token = 50257  # see whisper config decoder_start_token_id

decoder_inputs = {"encoder_outputs": encoder_outputs[0]}

# Seed every per-layer key/value cache with random float32 placeholders.
# The decoding loop replaces the decoder entries after each step and the
# encoder entries after the first step.
for layer in range(num_hidden_layers):
    for branch, cache_shape in (("decoder", decoder_shape), ("encoder", encoder_shape)):
        for kind in ("key", "value"):
            cache_name = f"past_key_values.{layer}.{branch}.{kind}"
            decoder_inputs[cache_name] = np.random.rand(*cache_shape).astype(np.float32)


# Greedy decoding: feed one token per step, take the argmax of the logits as
# the next token, and stop on EOS or once decoder_max_length tokens exist.
eos_token_id = processor.tokenizer.eos_token_id  # loop-invariant, looked up once

while True:
    decoder_inputs["position_ids"] = position_ids
    decoder_inputs["decoder_attention_mask"] = decoder_attention_mask
    decoder_inputs["decoder_input_ids"] = np.array([[token]]).astype(np.int64)

    decoder_outputs = decoder.run(None, decoder_inputs)

    # Output 0 holds the logits; .item() requires them to reduce to a single id.
    token = np.argmax(decoder_outputs[0], axis=-1).item()

    # Record the token and test for termination BEFORE touching the mask:
    # on the last allowed step pos_id + 1 equals decoder_max_length, and
    # writing that mask column would raise IndexError instead of stopping.
    result.append(token)
    if token == eos_token_id or len(result) == decoder_max_length:
        break

    pos_id = int(position_ids[0, 0])
    # Expose the cache slot that the next step will attend to.
    decoder_attention_mask[:, pos_id + 1] = 1

    for i in range(num_hidden_layers):
        # This unintuitive indexing corresponds to the ONNX output order:
        # logits first, then per layer decoder key, decoder value,
        # encoder key, encoder value.
        base = 1 + 4 * i
        decoder_inputs[f"past_key_values.{i}.decoder.key"] = decoder_outputs[base]
        decoder_inputs[f"past_key_values.{i}.decoder.value"] = decoder_outputs[base + 1]

        # Prefill the KV cache for cross-attention: only the first step's
        # encoder key/value outputs are kept and reused afterwards.
        if pos_id == 0:
            decoder_inputs[f"past_key_values.{i}.encoder.key"] = decoder_outputs[base + 2]
            decoder_inputs[f"past_key_values.{i}.encoder.value"] = decoder_outputs[base + 3]

    # In-place increment: decoder_inputs["position_ids"] refers to this same array.
    position_ids += 1

print("result", result)
print(processor.tokenizer.decode(result))