patrickvonplaten sanchit-gandhi HF staff commited on
Commit
12c969c
1 Parent(s): 89d22e9

Fix examples: input_ids -> input_features (#2)

Browse files

- Fix examples: input_ids -> input_features (bb4003f2f80a17a9f6b736fae47f65aa63ba3eb2)
- Update README.md (2e3b22f123401d6a3dfef7756384ca5cfc352807)


Co-authored-by: Sanchit Gandhi <sanchit-gandhi@users.noreply.huggingface.co>

Files changed (1) hide show
  1. README.md +8 -9
README.md CHANGED
@@ -101,7 +101,7 @@ input_features = processor(
101
  sampling_rate=16_000,
102
  return_tensors="pt"
103
  ).input_features # Batch size 1
104
- generated_ids = model.generate(input_ids=input_features)
105
 
106
  transcription = processor.batch_decode(generated_ids)
107
  ```
@@ -112,29 +112,28 @@ The following script shows how to evaluate this model on the [LibriSpeech](https
112
  *"clean"* and *"other"* test dataset.
113
 
114
  ```python
115
- from datasets import load_dataset, load_metric
 
116
  from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
117
 
118
  librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") # change to "other" for other test dataset
119
- wer = load_metric("wer")
120
 
121
  model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr").to("cuda")
122
  processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr", do_upper_case=True)
123
 
124
- librispeech_eval = librispeech_eval.map(map_to_array)
125
-
126
  def map_to_pred(batch):
127
  features = processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
128
  input_features = features.input_features.to("cuda")
129
  attention_mask = features.attention_mask.to("cuda")
130
 
131
- gen_tokens = model.generate(input_ids=input_features, attention_mask=attention_mask)
132
- batch["transcription"] = processor.batch_decode(gen_tokens, skip_special_tokens=True)
133
  return batch
134
 
135
- result = librispeech_eval.map(map_to_pred, batched=True, batch_size=8, remove_columns=["speech"])
136
 
137
- print("WER:", wer(predictions=result["transcription"], references=result["text"]))
138
  ```
139
 
140
  *Result (WER)*:
 
101
  sampling_rate=16_000,
102
  return_tensors="pt"
103
  ).input_features # Batch size 1
104
+ generated_ids = model.generate(input_features=input_features)
105
 
106
  transcription = processor.batch_decode(generated_ids)
107
  ```
 
112
  *"clean"* and *"other"* test dataset.
113
 
114
  ```python
115
+ from datasets import load_dataset
116
+ from evaluate import load
117
  from transformers import Speech2TextForConditionalGeneration, Speech2TextProcessor
118
 
119
  librispeech_eval = load_dataset("librispeech_asr", "clean", split="test") # change to "other" for other test dataset
120
+ wer = load("wer")
121
 
122
  model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr").to("cuda")
123
  processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr", do_upper_case=True)
124
 
 
 
125
  def map_to_pred(batch):
126
  features = processor(batch["audio"]["array"], sampling_rate=16000, padding=True, return_tensors="pt")
127
  input_features = features.input_features.to("cuda")
128
  attention_mask = features.attention_mask.to("cuda")
129
 
130
+ gen_tokens = model.generate(input_features=input_features, attention_mask=attention_mask)
131
+ batch["transcription"] = processor.batch_decode(gen_tokens, skip_special_tokens=True)[0]
132
  return batch
133
 
134
+ result = librispeech_eval.map(map_to_pred, remove_columns=["audio"])
135
 
136
+ print("WER:", wer.compute(predictions=result["transcription"], references=result["text"]))
137
  ```
138
 
139
  *Result (WER)*: