More training details?

#1
by mutiann - opened

I've recently been fine-tuning W2V2 on TEDLIUM v3 as well, but I can only reach ~20% WER, even though I used your hyperparameters and reached similar training losses. Are there any special tricks during training? Would it be possible for you to share the code or more details?

Thanks in advance!

Hey @mutiann !

How many train steps did you train for? This model was trained for 50k train steps. The WER was still decreasing after 50k steps, so if you trained for longer you could probably do even better. Also, what pre-processing steps did you use? I'd highly recommend removing the <unk> token from the training transcriptions: https://github.com/sanchit-gandhi/seq2seq-speech/blob/cfc6d73959486f5bd71c623ddd95843d62f5a614/run_flax_speech_recognition_ctc.py#L986
These <unk> tokens only appear in the train split (not the dev or test splits) and are commonly removed in the literature.
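A minimal sketch of that filtering step, assuming the Hugging Face datasets library and the TEDLIUM "text" column; the linked script is the reference implementation and may differ in the details:

```python
from datasets import load_dataset

tedlium = load_dataset("LIUM/tedlium", "release3", split="train")

def remove_unk(batch):
    # Strip the <unk> markers and collapse any double spaces they leave behind.
    text = batch["text"].replace("<unk>", "")
    batch["text"] = " ".join(text.split())
    return batch

tedlium = tedlium.map(remove_unk)
```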

The script used to train the model was: https://github.com/sanchit-gandhi/seq2seq-speech/blob/main/run_flax_speech_recognition_ctc.py
It was trained on a TPU v3-8 for 50k train steps (~28 hours); the training logs are here: https://wandb.ai/sanchit-gandhi/tedlium/runs/10c85yc4?workspace=user-sanchit-gandhi

The command used to train the model:

```bash
python run_flax_speech_recognition_ctc.py \
    --model_name_or_path=speech-seq2seq/flax-wav2vec2-large-lv60-scan \
    --tokenizer_name=sanchit-gandhi/wav2vec2_ctc_tedlium_tokenizer \
    --dataset_name=LIUM/tedlium \
    --dataset_config_name=release3 \
    --train_split_name=train \
    --eval_split_name=validation \
    --test_split_name=test \
    --text_column_name=text \
    --hidden_dropout=0.2 \
    --activation_dropout=0.2 \
    --feat_proj_dropout=0.2 \
    --output_dir=./flax-wav2vec2-ctc-tedlium-hidden-activation-featproj-dropout-0.2 \
    --wandb_project=tedlium \
    --wandb_name=flax-wav2vec2-ctc-tedlium-hidden-activation-featproj-dropout-0.2 \
    --dataset_cache_dir=/home/sanchitgandhi/cache/huggingface/datasets \
    --max_steps=50000 \
    --save_steps=10000 \
    --eval_steps=10000 \
    --learning_rate=3e-4 \
    --logging_steps=25 \
    --warmup_steps=5000 \
    --preprocessing_num_workers=1 \
    --do_train \
    --do_eval \
    --do_predict \
    --overwrite_output_dir \
    --gradient_checkpointing \
    --freeze_feature_encoder \
    --push_to_hub \
    --use_auth_token
```
sanchit-gandhi changed discussion status to closed

Thank you!

After removing the UNKs, the WER decreases to ~15%. It does matter!

I feel that the batch size might be an issue, as you use a total batch size of 64, while I don't have that many GPUs :( I'm now limiting each batch to a total of 2.4e6, which corresponds to a batch size of 24.6 on average.
How much impact do you think the batch size could have?
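For context, the batching scheme described above (capping each batch by its total length rather than by a fixed number of utterances) looks roughly like the sketch below. The 2.4e6 budget is assumed here to count audio samples per batch (the unit isn't stated), and the actual dataloader.py will differ:

```python
def batch_by_total_size(utterances, max_total=2_400_000):
    """Group utterances so that each batch's summed length stays under a budget."""
    batch, total = [], 0
    for utt in utterances:  # each utt is assumed to carry a "num_samples" field
        if batch and total + utt["num_samples"] > max_total:
            yield batch
            batch, total = [], 0
        batch.append(utt)
        total += utt["num_samples"]
    if batch:
        yield batch
```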

Great, that's good to hear!

In Flax, I struggled to get training to work with a batch size of 16; it fared better with 32+. Could you use gradient accumulation to bump up your effective batch size?
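One way to do that with optax, shown as a hedged sketch (the exact optimizer setup in run_flax_speech_recognition_ctc.py may differ): optax.MultiSteps accumulates gradients over k micro-batches before applying an update, so a per-step batch of 16 with k=4 behaves like an effective batch size of 64.

```python
import optax

# Base optimizer matching the learning rate in the training command (3e-4).
base_optimizer = optax.adamw(learning_rate=3e-4)

# Accumulate gradients over 4 micro-batches before each parameter update.
optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=4)

# optimizer.init(params) and optimizer.update(grads, opt_state, params) are
# then used exactly as with the unwrapped optimizer.
```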

sanchit-gandhi changed discussion status to open

I tried that, and with an effective batch size of ~72 the WER went down to ~14%...
I guess I should check your settings further and see if I missed anything...

Feel free to send through your code/repo and I can take a quick look to see if anything jumps out!

Note that the number of train steps is extremely important too. I trained for 50k train steps with an effective batch size of 64, equivalent to 3200k training examples. After this, the eval loss and eval WER were still decreasing, so training for longer is also an option.

[Screenshot: eval loss and eval WER still decreasing at 50k train steps]

Yes, I saw that, and I trained for around 100k steps. I also checked your logs, and my WER has been much higher since the first evaluation at 10k steps (17% vs 14%)...

Sure, if you send your code/repo over I can take a quick look!

Hi, I'm coming back again :(
I'm afraid reading the code would take up too much of your time... The core part is available here: https://1drv.ms/u/s!AukS30WbNwYfgaKJBPDSvIJ3Gx9rjzA?e=SvSRSG
It might be a bit messy, because for debugging purposes I plugged part of your data processing code into mine, to ensure that the model gets exactly the same input at each step in your code and mine; everything other than dataloader.py should be readable. However, my results don't change much and are still much worse than yours, so the problem shouldn't come from differences in the input data but from somewhere inside the model or the optimization.
I'm still trying to work out where the difference is. Is it possible to extract intermediate values, like the outputs of each layer? I read some of the Flax docs (e.g. https://flax.readthedocs.io/en/latest/guides/extracting_intermediates.html), but using their methods only gives me JVPTracer objects rather than the actual intermediate values. Sorry, I've never used Flax before :(
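For reference, a minimal sketch of the capture_intermediates approach from the linked Flax guide, using a placeholder model with random weights (the actual checkpoint in this thread will differ). The JVPTracer objects typically appear when apply() runs inside a jax.jit / jax.grad / pmap transform; calling it eagerly returns concrete arrays:

```python
import jax.numpy as jnp
from transformers import FlaxWav2Vec2ForCTC, Wav2Vec2Config

# Placeholder model purely for illustration (random weights, default config);
# swap in the fine-tuned checkpoint when debugging for real.
model = FlaxWav2Vec2ForCTC(Wav2Vec2Config())
dummy_inputs = jnp.ones((1, 16000), dtype=jnp.float32)  # 1 s of fake audio

# capture_intermediates=True makes Flax record every sub-module's output.
# Crucially, apply() is called eagerly here (outside jax.jit / jax.grad),
# so the recorded values are concrete arrays rather than JVPTracers.
outputs, state = model.module.apply(
    {"params": model.params},
    dummy_inputs,
    capture_intermediates=True,
    mutable=["intermediates"],
)
intermediates = state["intermediates"]  # nested dict of per-layer outputs
```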

I've just fixed a bug in my implementation and it works now :)
Thank you so much for all your aid!

mutiann changed discussion status to closed

Amazing, very happy to hear that @mutiann ! Best of luck!
