stupidog04 commited on
Commit
d6c86e9
·
1 Parent(s): 62716ae

Delete pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +0 -46
pipeline.py DELETED
@@ -1,46 +0,0 @@
1
- import numpy as np
2
- from typing import Dict
3
-
4
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
- from pyctcdecode import Alphabet, BeamSearchDecoderCTC
6
-
7
- class PreTrainedPipeline():
8
- def __init__(self, path):
9
- """
10
- Initialize model
11
- """
12
- self.processor = Wav2Vec2Processor.from_pretrained(path)
13
- self.model = Wav2Vec2ForCTC.from_pretrained(path)
14
- vocab_list = list(self.processor.tokenizer.get_vocab().keys())
15
-
16
- # convert ctc blank character representation
17
- vocab_list[0] = ""
18
-
19
- # replace special characters
20
- vocab_list[1] = "⁇"
21
- vocab_list[2] = "⁇"
22
- vocab_list[3] = "⁇"
23
-
24
- # convert space character representation
25
- vocab_list[4] = " "
26
-
27
- alphabet = Alphabet.build_alphabet(vocab_list, ctc_token_idx=0)
28
-
29
- self.decoder = BeamSearchDecoderCTC(alphabet)
30
- self.sampling_rate = 16000
31
-
32
-
33
- def __call__(self, inputs)-> Dict[str, str]:
34
- """
35
- Args:
36
- inputs (:obj:`np.array`):
37
- The raw waveform of audio received. By default at 16KHz.
38
- Return:
39
- A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
40
- the detected text from the input audio.
41
- """
42
- input_values = self.processor(inputs, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
43
- logits = self.model(input_values).logits.cpu().detach().numpy()[0]
44
- return {
45
- "text": self.decoder.decode(logits)
46
- }