Fix mixup of `<pad>` and `<s>` tokens in vocab
This model outputs many `<s>` tokens, including in the middle of words. You can observe this by running it locally or by using the widget on this page.
It seems to be fixed by swapping the vocab ids of `<s>` and `<pad>`.
Other GroNLP models also seem to be affected by this, for example https://huggingface.co/GroNLP/wav2vec2-dutch-large-ft-cgn
Thank you for your interest in this model. I checked, and the problem is not that `<s>` and `<pad>` are swapped, but that the output contains spurious `<s>` tokens. I noticed this myself when we made this model. The `<s>` token should be part of the output, but ideally only as the first token. My hypothesis is that the model might confuse silences with the short silences at the start of the audio in the training data. However, whether these spurious `<s>` tokens appear depends on your exact audio.
The `<pad>` token should never be part of any output. This token is only used during training, where it is masked, so the model will not produce it as output.
The general solution to this problem at inference is to pass `skip_special_tokens=True` to your decode call (see the docs). Maybe I can enable this in the inference widget as well, but I'll have to look into that.
Let me know if this does not solve your problem or if you need more help.
Thank you very much for the super-fast response and the suggestion!
The reason I suspected a swap in the vocab is that the ids in `vocab.json` don't correspond to the ids in `config.json`:
"bos_token_id": 1,
"eos_token_id": 2,
"pad_token_id": 0,
However, you're right that it's even weirder to have `<pad>` tokens in the middle of the output than to have `<s>` tokens.
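To make the mismatch concrete, here is a small sketch that compares the special-token ids in `config.json` with the ids `vocab.json` assigns to the same tokens. The helper `mismatched_special_ids` is hypothetical, and the vocab fragment is an assumption inferred from the proposed fix (swapping `<s>` and `<pad>`) plus fairseq's usual `<s>`, `<pad>`, `</s>`, `<unk>` ordering; only the config values are quoted from above.

```python
import json

def mismatched_special_ids(vocab: dict, config: dict) -> dict:
    """Return {token: (vocab_id, config_id)} for special tokens whose ids disagree (hypothetical helper)."""
    # Assumed mapping from token strings to the config.json keys that should point at them.
    config_keys = {"<pad>": "pad_token_id", "<s>": "bos_token_id", "</s>": "eos_token_id"}
    mismatches = {}
    for token, key in config_keys.items():
        if token in vocab and key in config and vocab[token] != config[key]:
            mismatches[token] = (vocab[token], config[key])
    return mismatches

# ASSUMED vocab.json fragment (inferred, not copied from the repository).
vocab = json.loads('{"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}')
# Special-token ids as quoted from config.json in this thread.
config = json.loads('{"bos_token_id": 1, "eos_token_id": 2, "pad_token_id": 0}')

print(mismatched_special_ids(vocab, config))
```

With the real files, the dictionaries returned by `json.load` on `vocab.json` and `config.json` can be passed in instead; an empty result means the two files agree.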
As to your suggestion: I tried `skip_special_tokens=True` before, because I agree that it should be the cleanest way to get readable output. However, with this model I observe that it collapses every run of repeated letters into a single letter. This is best shown by this example from Common Voice (nl) (sorry, it is part of our test data, and I don't have the original filename available right now).
tokenizer.decode(predicted_ids[0])
"<s>J<s>A<s> <s>GOE<s>D<s>E<s>M<s>OR<s>G<s>EN<s> <s>M<s>A<s>AR<s>T<s>E<s>N<s> P<s>A<s>S<s>C<s>A<s>L<s> <s>J<s>A<s>Z<s>E<s>K<s>E<s>R<s> HE<s>EFT <s>IE D<s>E W<s>AN<s>DEL<s>SCHOE<s>N<s>EN <s>A<s>A<s>N<s> <s>W<s>AN<s>T<s> <s>HET <s>IS N<s>OG<s> E<s>EN STU<s>K<s>J<s>E<s> <s>S<s><unk><s>L<s> <s>HOE<s>V<s>E<s>EL<s> <s>K<s>I<s>L<s>O<s>M<s>E<s>T<s>E<s>R<s>S<s> <s>Z<s>IJ<s>N<s> <s>'<s>T<s> <s>T<s>O<s>T <s>IN D<s>EN <s>H<s>A<s>A<s>G<s>? <s>"
tokenizer.decode(predicted_ids[0]).replace('<s>', '')
"JA GOEDEMORGEN MAARTEN PASCAL JAZEKER HEEFT IE DE WANDELSCHOENEN AAN WANT HET IS NOG EEN STUKJE S<unk>L HOEVEEL KILOMETERS ZIJN 'T TOT IN DEN HAAG? "
tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
"JA GOEDEMORGEN MARTEN PASCAL JAZEKER HEFT IE DE WANDELSCHOENEN AN WANT HET IS NOG EN STUKJE SL HOEVEL KILOMETERS ZIJN 'T TOT IN DEN HAG?"
Now I suspect that this is somehow a bug in `transformers` that is being triggered here, because I don't observe this behavior in all models. But it does prevent me from using `skip_special_tokens` for now.
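The lost doubled letters can be reproduced with a simplified model of CTC decoding. This sketch assumes (it is not taken from the `transformers` source) that `skip_special_tokens=True` drops special tokens *before* consecutive duplicates are grouped. The frame sequence is made up, with the blank labeled `<s>` as in the current vocab.

```python
from itertools import groupby

# Assumed set of special tokens for this model.
SPECIAL_TOKENS = {"<s>", "</s>", "<pad>", "<unk>"}

def simplified_ctc_decode(frames, pad_token="<pad>", skip_special_tokens=False):
    """Simplified (assumed) model of a CTC tokenizer's decode step."""
    if skip_special_tokens:
        # Dropping special tokens first removes whatever separated two identical
        # letters, so they become adjacent and are merged by the grouping below.
        frames = [t for t in frames if t not in SPECIAL_TOKENS]
    # Collapse runs of identical tokens (CTC grouping).
    grouped = [tok for tok, _ in groupby(frames)]
    # Remove only the token the decoder believes is the CTC blank.
    return "".join(tok for tok in grouped if tok != pad_token)

# Hypothetical frame-level output for "MAARTEN" where the blank is labeled "<s>".
frames = ["M", "M", "<s>", "A", "<s>", "A", "A", "R", "<s>", "T", "E", "N"]
print(simplified_ctc_decode(frames))                            # M<s>A<s>AR<s>TEN
print(simplified_ctc_decode(frames, skip_special_tokens=True))  # MARTEN
```

Both outputs mirror what is reported above: spurious single `<s>` tokens by default, and a missing doubled letter (MAARTEN becomes MARTEN) with `skip_special_tokens=True`.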
Thank you for this clarification! The config and vocab indeed contradict each other, so you might be correct that they should be switched, depending on when the mix-up occurred. I suspect that it has to do with the conversion from `fairseq` to `transformers`. I cannot test this right now, but could you share the exact output with your changed vocab, both with and without `skip_special_tokens=True`?
The output in the reply above was produced with the code as-is.
If you apply this PR, the output becomes:
tokenizer.decode(predicted_ids[0])
"JA GOEDEMORGEN MAARTEN PASCAL JAZEKER HEEFT IE DE WANDELSCHOENEN AAN WANT HET IS NOG EEN STUKJE S<unk>L HOEVEEL KILOMETERS ZIJN 'T TOT IN DEN HAAG?"
tokenizer.decode(predicted_ids[0]).replace('<s>', '')
"JA GOEDEMORGEN MAARTEN PASCAL JAZEKER HEEFT IE DE WANDELSCHOENEN AAN WANT HET IS NOG EEN STUKJE S<unk>L HOEVEEL KILOMETERS ZIJN 'T TOT IN DEN HAAG?"
tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
"JA GOEDEMORGEN MARTEN PASCAL JAZEKER HEFT IE DE WANDELSCHOENEN AN WANT HET IS NOG EN STUKJE SL HOEVEL KILOMETERS ZIJN 'T TOT IN DEN HAG?"
So changing the ids does not change the odd behavior of `skip_special_tokens=True`, but it does make the default output more readable. There could still be something we're missing here, though; I'll update if I find out more.
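If the decoder does drop special tokens before grouping, one workaround is to reverse the order: collapse repeats first, while the blank still separates doubled letters, and only then remove the blank and other special tokens. That is effectively what `.replace('<s>', '')` on the default output does for the `<s>` case. A minimal pure-Python sketch, assuming the blank is labeled `<pad>` (as after this PR); both function names and the frame sequence are hypothetical.

```python
from itertools import groupby

# Assumed set of special tokens for this model.
SPECIAL_TOKENS = {"<s>", "</s>", "<pad>", "<unk>"}

def decode_skip_first(frames):
    # Assumed current ordering: drop special tokens, then collapse repeats.
    kept = [t for t in frames if t not in SPECIAL_TOKENS]
    return "".join(tok for tok, _ in groupby(kept))

def decode_group_first(frames):
    # Workaround ordering: collapse repeats while the blank still separates
    # doubled letters, then drop the blank and any other special tokens.
    grouped = [tok for tok, _ in groupby(frames)]
    return "".join(tok for tok in grouped if tok not in SPECIAL_TOKENS)

# Hypothetical frame-level output for "HAAG" with the blank labeled "<pad>".
frames = ["H", "<pad>", "A", "A", "<pad>", "A", "G", "<pad>"]
print(decode_skip_first(frames))   # HAG
print(decode_group_first(frames))  # HAAG
```

Note that this sketch also drops `<unk>`; to keep unknown characters visible, as the `.replace('<s>', '')` output does, remove `<unk>` from the set.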
By the way, I've written a stand-alone script that demonstrates the behavior, rather than debugging within a larger repo.
import torch
from datasets import load_dataset, Audio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
model = Wav2Vec2ForCTC.from_pretrained("GroNLP/wav2vec2-large-xlsr-53-ft-cgn")
processor = Wav2Vec2Processor.from_pretrained("GroNLP/wav2vec2-large-xlsr-53-ft-cgn")
# load first sample of Dutch common_voice
dataset = load_dataset("common_voice", "nl", split="test", streaming=True)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
sample = next(iter(dataset))
# forward sample through model to get greedily predicted transcription ids
input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)[0]
# decode predicted tokens
print(processor.decode(predicted_ids))
print(processor.decode(predicted_ids).replace('<s>', ''))
print(processor.decode(predicted_ids, skip_special_tokens=True))
print()
print(processor.decode(predicted_ids, group_tokens=False))
print(processor.decode(predicted_ids, skip_special_tokens=True, group_tokens=False))
Original output:
<s>H<s>ET<s> <s>C<s>O<s>N<s>T<s>AI<s>N<s>E<s>SCH<s>I<s>P<s> <s>L<s>A<s>G<s> <s>A<s>AN<s>G<s>E<s>M<s>E<s>ER<s>D<s> <s>IN D<s>E <s>H<s>A<s>V<s>EN<s> <s>
HET CONTAINESCHIP LAG AANGEMEERD IN DE HAVEN
HET CONTAINESCHIP LAG ANGEMERD IN DE HAVEN
<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>H<s>ET<s> <s>C<s><s>O<s>N<s><s>T<s><s><s>AI<s><s>N<s><s>E<s><s><s><s>SCH<s><s>I<s><s>P<s><s> <s>L<s><s>A<s>G<s> <s><s>A<s><s>AN<s><s><s>G<s>E<s><s>M<s>E<s><s>ER<s><s>D<s> <s>IN D<s>E <s>H<s><s>AA<s><s><s>V<s><s><s>EN<s><s><s> <s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>
HET CONTAINESCHIP LAG AANGEMEERD IN DE HAAVEN
PR output:
HET CONTAINESCHIP LAG AANGEMEERD IN DE HAVEN
HET CONTAINESCHIP LAG AANGEMEERD IN DE HAVEN
HET CONTAINESCHIP LAG ANGEMERD IN DE HAVEN
HET CONTAINESCHIP LAG AANGEMEERD IN DE HAAVEN
HET CONTAINESCHIP LAG AANGEMEERD IN DE HAAVEN
I'm new to the ASR field, so I learned a few new things yesterday:
- The raw output of the CTC algorithm includes a blank token; to form the final output, repeated tokens (including blanks) are collapsed and the blanks are then removed: https://distill.pub/2017/ctc/
- Fine-Tune Wav2Vec2 for English ASR with 🤗 Transformers strongly implies that the `<pad>` token fills the role of the blank token.
- You can ask the tokenizer to decode the raw output by specifying `group_tokens=False`.
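The CTC decoding steps in the list above can be sketched in pure Python: take the highest-scoring token per frame, collapse consecutive duplicates, then drop the blank. The toy vocabulary and scores are made up for illustration, and, following the tutorial's convention, the blank is assumed to be the `<pad>` id.

```python
from itertools import groupby

def ctc_greedy_decode(logits, id_to_token, blank_id):
    """Greedy CTC decoding: per-frame argmax, collapse repeats, then drop blanks."""
    # 1. Pick the highest-scoring token id in every frame.
    best_ids = [max(range(len(frame)), key=frame.__getitem__) for frame in logits]
    # 2. Collapse consecutive duplicates (blank frames included).
    collapsed = [i for i, _ in groupby(best_ids)]
    # 3. Remove blanks; a blank between two identical ids keeps them apart.
    return "".join(id_to_token[i] for i in collapsed if i != blank_id)

# Toy vocabulary: id 0 is assumed to be the blank (<pad>).
id_to_token = {0: "<pad>", 1: "L", 2: "A", 3: "G"}
# Made-up per-frame scores for 7 frames spelling "LAAG".
logits = [
    [0.1, 0.9, 0.0, 0.0],  # L
    [0.1, 0.8, 0.1, 0.0],  # L (repeat, collapsed)
    [0.1, 0.0, 0.9, 0.0],  # A
    [0.9, 0.0, 0.1, 0.0],  # blank separates the two A's
    [0.2, 0.0, 0.8, 0.0],  # A
    [0.1, 0.0, 0.0, 0.9],  # G
    [0.9, 0.0, 0.0, 0.1],  # blank
]
print(ctc_greedy_decode(logits, id_to_token, blank_id=0))  # LAAG
```

In the script above, `torch.argmax(logits, dim=-1)` performs step 1; steps 2 and 3 happen inside `processor.decode`.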
I've added the raw output to the example above. There you can clearly see that the tokenizer outputs a lot of `<s>` tokens, which get collapsed by the CTC decoding into spurious single instances. To me, this looks consistent with the number of blank tokens I would expect.
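The raw view explains the spurious `<s>` tokens: if the blank id carries the label `<s>` while the decoder only removes tokens named `<pad>`, each run of blanks survives grouping as a single `<s>`. Swapping the two labels, as this PR does, lets the same ids decode cleanly. A simplified sketch; the raw ids and both label mappings are hypothetical.

```python
from itertools import groupby

def decode(ids, id_to_token, pad_token="<pad>"):
    """Simplified CTC decode: collapse repeated ids, then drop whatever token is named pad_token."""
    tokens = [id_to_token[i] for i, _ in groupby(ids)]
    return "".join(t for t in tokens if t != pad_token)

# Hypothetical raw ids (as seen with group_tokens=False): runs of blank id 0 between letters.
raw_ids = [0, 0, 0, 4, 0, 5, 5, 0, 0, 5, 6, 0, 0]
# Assumed labels before the fix (id 0 named "<s>") and after it (id 0 named "<pad>").
vocab_before = {0: "<s>", 1: "<pad>", 2: "</s>", 3: "<unk>", 4: "H", 5: "A", 6: "G"}
vocab_after = {0: "<pad>", 1: "<s>", 2: "</s>", 3: "<unk>", 4: "H", 5: "A", 6: "G"}

print(decode(raw_ids, vocab_before))  # <s>H<s>A<s>AG<s>
print(decode(raw_ids, vocab_after))   # HAAG
```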
Thank you for researching this issue and for this extensive help! You seem to be right, and the issue should indeed be solved by your patch. I will merge this and apply the change to the other models. I checked the original fairseq models, and it is indeed the vocab that needs to be changed, not the config. Thanks!