File size: 1,447 Bytes
a1be16b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
#!/usr/bin/env bash
#
# Launch distillation training with `accelerate`.
#   student: distil-whisper/large-32-2
#   teacher: openai/whisper-large-v2
#   data:    LibriSpeech (train.clean.100, train.clean.360, train.other.500)
#            + GigaSpeech-L, streamed
#   eval:    LibriSpeech validation.clean / validation.other
#            + GigaSpeech-L validation
#
# Usage:
#   ./<this script> [extra run_distillation_pt.py flags...]
#
# Environment overrides (defaults reproduce the original hard-coded values):
#   DISTIL_NUM_PROCESSES      value for `accelerate launch --num_processes`
#                             (default: 2)
#   DISTIL_CACHE_DIR          value for --cache_dir
#                             (default: /fsx/sanchit/cache)
#   DISTIL_DATASET_CACHE_DIR  value for --dataset_cache_dir
#                             (default: same as DISTIL_CACHE_DIR)
#
# Extra command-line arguments are appended after the defaults below.
# NOTE(review): overriding a default flag this way assumes the Python script's
# argument parser lets the last occurrence of a flag win (argparse behaviour
# for plain options) — confirm for run_distillation_pt.py.
set -euo pipefail

# Variable names are prefixed so unrelated CACHE_DIR-style variables already
# in the caller's environment cannot silently redirect caches.
readonly num_processes="${DISTIL_NUM_PROCESSES:-2}"
readonly cache_dir="${DISTIL_CACHE_DIR:-/fsx/sanchit/cache}"
readonly dataset_cache_dir="${DISTIL_DATASET_CACHE_DIR:-${cache_dir}}"

# Arguments consumed by `accelerate launch` itself.
launch_args=(
  --multi_gpu
  --mixed_precision=bf16
  --num_processes="${num_processes}"
)

# Arguments forwarded to run_distillation_pt.py.
train_args=(
  # Student checkpoint and the teacher it is distilled from.
  --model_name_or_path distil-whisper/large-32-2
  --teacher_model_name_or_path openai/whisper-large-v2

  # Training mixture: each flag holds four '+'-separated entries that appear
  # to be matched by position (3 LibriSpeech splits + GigaSpeech-L).
  # NOTE(review): --train_dataset_samples presumably sets per-dataset sampling
  # proportions — confirm its units in run_distillation_pt.py.
  --train_dataset_config_name all+all+all+l
  --train_dataset_samples 2.9+10.4+14.9+226.6
  --train_dataset_name librispeech_asr+librispeech_asr+librispeech_asr+gigaspeech-l
  --train_split_name train.clean.100+train.clean.360+train.other.500+train

  # Evaluation sets: three '+'-separated entries per flag.
  --eval_dataset_name librispeech_asr+librispeech_asr+gigaspeech-l
  --eval_dataset_config_name all+all+l
  --eval_split_name validation.clean+validation.other+validation
  --eval_text_column_name text+text+text

  # Schedule and checkpointing.
  --eval_steps 2500
  --save_steps 2500
  --warmup_steps 50
  --learning_rate 0.0001
  --lr_scheduler_type constant_with_warmup
  --logging_steps 25
  --save_total_limit 1
  --max_steps 10000
  --wer_threshold 10

  # Batch sizes and data loading.
  --per_device_train_batch_size 64
  --gradient_accumulation_steps 2
  --per_device_eval_batch_size 64
  --dataloader_num_workers 16

  # Storage, precision and logging.
  --cache_dir "${cache_dir}"
  --dataset_cache_dir "${dataset_cache_dir}"
  --dtype bfloat16
  --output_dir ./
  --wandb_project distil-whisper-training

  # Behaviour switches.
  --do_train
  --do_eval
  --gradient_checkpointing
  --overwrite_output_dir
  --predict_with_generate
  --freeze_encoder
  --streaming
)

accelerate launch "${launch_args[@]}" run_distillation_pt.py \
  "${train_args[@]}" "$@"
|