scottykwok committed
Commit 7f3b4fd · 1 Parent(s): f4bad18

Update README.md

Files changed (1)
  1. README.md +17 -32
README.md CHANGED
````diff
@@ -25,45 +25,30 @@ See this GitHub Repo [cantonese-selfish-project](https://github.com/scottykwok/c
 
 # Usage
 ```python
-import time
+import soundfile as sf
 import torch
-import torchaudio
 from datasets import load_dataset
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
-import sys
 
-# inputs
-model_id = "scottykwok/wav2vec2-large-xlsr-cantonese"
-try:
-    wav_file = sys.argv[1]
-except:
-    print("Please provide an input wav filename ")
-    exit(-1)
+# load pretrained model
+processor = Wav2Vec2Processor.from_pretrained("scottykwok/wav2vec2-large-xlsr-cantonese")
+model = Wav2Vec2ForCTC.from_pretrained("scottykwok/wav2vec2-large-xlsr-cantonese")
 
-print("-" * 20)
-print("Model ID:", model_id)
-print("Input Audio:", wav_file)
-print("-" * 20)
+# load audio - must be 16kHz mono
+audio_input, sample_rate = sf.read('audio.wav')
 
-# load model and tokenizer
-processor = Wav2Vec2Processor.from_pretrained(model_id)
-model = Wav2Vec2ForCTC.from_pretrained(model_id)
+# pad input values and return pt tensor
+input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values
 
-# read audio to numpy
-def speech_to_array(path):
-    speech_array, sampling_rate = torchaudio.load(path)
-    return speech_array.squeeze().numpy()
-input_array = speech_to_array(wav_file)
+# INFERENCE
+# retrieve logits & take argmax
+logits = model(input_values).logits
+predicted_ids = torch.argmax(logits, dim=-1)
 
-# tokenize
-inputs = processor([input_array], sampling_rate=16_000, return_tensors="pt", padding=True)
+# transcribe
+transcription = processor.decode(predicted_ids[0])
+print("-" * 20)
+print("Transcription:\n", transcription.lower())
+print("-" * 20)
 
-# inference
-with torch.no_grad():
-    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
-predicted_ids = torch.argmax(logits, dim=-1)
-pred = processor.batch_decode(predicted_ids)
-print("-" * 20)
-print("Prediction:", pred)
-print("-" * 20)
 ```
````
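The rewritten snippet reads `audio.wav` as-is and assumes the file is already 16 kHz mono, the rate this wav2vec2 model expects; taking the argmax of the CTC logits is greedy decoding. As a minimal sketch (not part of the commit), input at other sample rates could be converted first with torchaudio, which the removed snippet already imported; the filename and the 16 kHz target below simply mirror the snippet above:

```python
# Minimal sketch (not part of this commit): prepare arbitrary audio as the
# 16 kHz mono input the snippet above expects, using torchaudio (which the
# removed snippet already imported). 'audio.wav' is a placeholder filename.
import torchaudio

waveform, orig_sr = torchaudio.load("audio.wav")  # shape: (channels, samples)
waveform = waveform.mean(dim=0)                   # downmix to mono
if orig_sr != 16_000:
    waveform = torchaudio.functional.resample(waveform, orig_freq=orig_sr, new_freq=16_000)
audio_input = waveform.numpy()                    # feed this to the processor
```

With `audio_input` prepared this way, the rest of the snippet runs unchanged; `processor.decode(predicted_ids[0])` is the single-utterance counterpart of the `batch_decode` call in the old snippet, and wrapping the forward pass in the old snippet's `torch.no_grad()` would still be a sensible addition for inference.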