Forcing decoder to transcribe English is not working
Hi there!
I am following the sample code in the README to transcribe a couple of English audio files, but forcing the decoder language is not working: I am getting transcriptions in other languages. Here is my code:
```python
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import Dataset, Audio

# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")

ds = Dataset.from_dict({"audio": ['<audio_path>']}).cast_column("audio", Audio())
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features

# generate token ids
predicted_ids = model.generate(input_features)

# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
transcription
```
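For context, `get_decoder_prompt_ids` returns a list of `(position, token_id)` pairs, and generation is supposed to force exactly those tokens at those decoding steps. Here is a toy sketch of that forcing mechanism (the vocab size and token ids below are made up for illustration; real Whisper ids differ):

```python
# Toy sketch of the mechanism behind forced_decoder_ids.
# The vocab size and (position, token_id) pairs are illustrative
# placeholders, not real Whisper token ids.

VOCAB_SIZE = 10
forced_decoder_ids = [(1, 7), (2, 3)]  # force token 7 at step 1, token 3 at step 2

def apply_forcing(step, logits, forced):
    """If this decoding step is forced, mask every logit except the forced token."""
    forced_map = dict(forced)
    if step in forced_map:
        tok = forced_map[step]
        logits = [float("-inf")] * len(logits)
        logits[tok] = 0.0
    return logits

def greedy_pick(logits):
    """Greedy decoding: pick the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])

# Simulate three decoding steps with arbitrary raw logits:
for step in range(1, 4):
    raw = [0.1 * i for i in range(VOCAB_SIZE)]  # unconstrained, this favors token 9
    print(step, greedy_pick(apply_forcing(step, raw, forced_decoder_ids)))
```

At steps 1 and 2 the forced tokens win regardless of the raw scores; at step 3 the model's own preference takes over. The bug report above suggests these forced positions are being ignored when set via `model.config`.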
And the result:
```
['<|startoftranscript|><|bo|><|transcribe|><|notimestamps|> [...] <|endoftext|>']
```
Unfortunately, I cannot share the audio file because it is private.
Some recent change seems to be causing this: when I tested three weeks ago I was getting ~30 WER, and now it is ~70 WER, with the degradation coming from transcriptions in the wrong language.
@ArthurZ, the change to `.generate()` was backwards compatible with `forced_decoder_ids`, no? Maybe you could take a look here!
Any updates on this bug?
Hey @lucas-aixplain, thanks for flagging this! I was able to reproduce:
```python
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features

# set the forced ids (French, so the bug is obvious on English audio)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="transcribe")

# generate token ids
predicted_ids = model.generate(input_features)

# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
print(transcription)
```
Print output:
```
['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|>']
```
Opened an issue on Transformers to track: https://github.com/huggingface/transformers/issues/21937