MartaKozina committed on
Commit
84400b8
·
1 Parent(s): 9942820

Upload wav2vec2.py

Browse files
Files changed (1) hide show
  1. wav2vec2.py +60 -0
wav2vec2.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import torch
3
+ import torchaudio
4
+
5
# Default matplotlib figure size (width, height) applied process-wide.
matplotlib.rcParams.update({"figure.figsize": [16.0, 4.8]})

# Seed torch's default random number generator with a fixed value.
torch.manual_seed(0)

# Use CUDA when torch reports it as available, otherwise run on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
9
+
10
+ # print(torch.__version__)
11
+ # print(torchaudio.__version__)
12
+ # print(device)
13
+ #
14
+ # SPEECH_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/source-16k/train/sp0307/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav" # noqa: E501
15
+ # SPEECH_FILE = "_assets/speech.wav"
16
+ #
17
+ # if not os.path.exists(SPEECH_FILE):
18
+ # os.makedirs("_assets", exist_ok=True)
19
+ # with open(SPEECH_FILE, "wb") as file:
20
+ # file.write(requests.get(SPEECH_URL).content)
21
+
22
+
23
class GreedyCTCDecoder(torch.nn.Module):
    """Greedy (best-path) CTC decoder that turns frame-wise emissions into text.

    Args:
        labels: Collection indexable by integer label id whose items are the
            output strings (e.g. a tuple of characters, or a mapping keyed by
            int).
        blank: Label id treated as the CTC blank and dropped from the output.
            Defaults to 0.
    """

    def __init__(self, labels, blank=0):
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string.

        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        # Most likely label at every frame.
        best_path = torch.argmax(emission, dim=-1)  # [num_seq,]
        # Merge runs of the same label into a single occurrence.
        collapsed = torch.unique_consecutive(best_path, dim=-1)
        # Convert to plain Python ints once: avoids a tensor comparison and a
        # tensor-based lookup per frame, and lets `labels` be any container
        # indexable by int (including mappings).
        label_ids = collapsed.tolist()
        return "".join(self.labels[i] for i in label_ids if i != self.blank)
41
+
42
def predict(file):
    """Transcribe an audio file with the WAV2VEC2_ASR_BASE_960H pipeline.

    The audio is loaded with ``torchaudio.load``, moved to the module-level
    ``device``, resampled to the bundle's sample rate when it differs, run
    through the model, and greedily CTC-decoded.

    Args:
        file: Audio source passed straight to ``torchaudio.load``.

    Returns:
        str: The decoded transcript of the first row of the model emission.
    """
    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    model = bundle.get_model().to(device)

    waveform, sample_rate = torchaudio.load(file)
    waveform = waveform.to(device)

    # The model expects audio at the bundle's sample rate.
    if sample_rate != bundle.sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

    # A single forward pass is enough: the previous extra
    # `model.extract_features(waveform)` call produced a result that was
    # never used and only repeated the computation.
    with torch.inference_mode():
        emission, _ = model(waveform)

    decoder = GreedyCTCDecoder(labels=bundle.get_labels())
    # Only the first row of the emission is decoded.
    # NOTE(review): presumably that row is the first audio channel, so extra
    # channels of a multi-channel file would be ignored - confirm.
    return decoder(emission[0])
60
+