MODEL="t5-base-dutch" MODEL_DIR="${HOME}/${MODEL}" mkdir -p "${MODEL_DIR}/runs" # T5 paper lr 0.01 with batch size 128 # We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2 # Warmup steps is set to 6% of the training steps ./run_t5_mlm_flax_custom_dataset.py \ --output_dir="${MODEL_DIR}" \ --model_type="t5" \ --config_name="flax-community/${MODEL}" \ --tokenizer_name="${MODEL_DIR}" \ --preprocessing_num_workers="96" \ --do_train --do_eval \ --adafactor \ --max_seq_length="512" \ --per_device_train_batch_size="32" \ --per_device_eval_batch_size="32" \ --learning_rate="1e-2" \ --dtype="bfloat16" \ --overwrite_output_dir \ --num_train_epochs="1" \ --logging_steps="50" \ --save_steps="300" \ --eval_steps="1000000" \ --push_to_hub #git add pytorch_model.bin #git commit -m "Update pytorch model after training" #git push origin main # --learning_rate="5e-3" \ # --gradient_accumulation_steps="2" \ # --resume_from_checkpoint="${MODEL_DIR}/ckpt-3300" \