Commit
·
7f3b4fd
1
Parent(s):
f4bad18
Update README.md
Browse files
README.md
CHANGED
@@ -25,45 +25,30 @@ See this GitHub Repo [cantonese-selfish-project](https://github.com/scottykwok/c
|
|
25 |
|
26 |
# Usage
|
27 |
```python
|
28 |
-
import
|
29 |
import torch
|
30 |
-
import torchaudio
|
31 |
from datasets import load_dataset
|
32 |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
33 |
-
import sys
|
34 |
|
35 |
-
#
|
36 |
-
|
37 |
-
|
38 |
-
wav_file = sys.argv[1]
|
39 |
-
except:
|
40 |
-
print("Please provide an input wav filename ")
|
41 |
-
exit(-1)
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
print("Input Audio:" , wav_file)
|
46 |
-
print("-"* 20)
|
47 |
|
48 |
-
#
|
49 |
-
processor =
|
50 |
-
model = Wav2Vec2ForCTC.from_pretrained(model_id)
|
51 |
|
52 |
-
#
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
input_array = speech_to_array(wav_file)
|
57 |
|
58 |
-
#
|
59 |
-
|
|
|
|
|
|
|
60 |
|
61 |
-
# inference
|
62 |
-
with torch.no_grad():
|
63 |
-
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
|
64 |
-
predicted_ids = torch.argmax(logits, dim=-1)
|
65 |
-
pred = processor.batch_decode(predicted_ids)
|
66 |
-
print("-"* 20)
|
67 |
-
print("Prediction:", pred)
|
68 |
-
print("-"* 20)
|
69 |
```
|
|
|
25 |
|
26 |
# Usage
|
27 |
```python
|
28 |
+
import soundfile as sf
|
29 |
import torch
|
|
|
30 |
from datasets import load_dataset
|
31 |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
|
|
|
32 |
|
33 |
+
# load pretrained model
|
34 |
+
processor = Wav2Vec2Processor.from_pretrained("scottykwok/wav2vec2-large-xlsr-cantonese")
|
35 |
+
model = Wav2Vec2ForCTC.from_pretrained("scottykwok/wav2vec2-large-xlsr-cantonese")
|
|
|
|
|
|
|
|
|
36 |
|
37 |
+
# load audio - must be 16kHz mono
|
38 |
+
audio_input, sample_rate = sf.read('audio.wav')
|
|
|
|
|
39 |
|
40 |
+
# pad input values and return pt tensor
|
41 |
+
input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
|
|
|
42 |
|
43 |
+
# INFERENCE
|
44 |
+
# retrieve logits & take argmax
|
45 |
+
logits = model(input_values).logits
|
46 |
+
predicted_ids = torch.argmax(logits, dim=-1)
|
|
|
47 |
|
48 |
+
# transcribe
|
49 |
+
transcription = processor.decode(predicted_ids[0])
|
50 |
+
print("-" *20)
|
51 |
+
print("Transcription:\n", transcription.lower())
|
52 |
+
print("-" *20)
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
```
|