|
#!/usr/bin/env bash
#
# SFT training launcher. Must run under bash (uses [[ ]], process
# substitution, and other bash-only features below).
[ -n "${BASH_VERSION}" ] || {
  echo "Please use bash to run this script." >&2
  exit 1
}

# Trace every command so launched runs are easy to debug from the logs.
set -x
|
|
|
# Resolve the directory containing this script and the repository root.
SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
# Fix: without this check a failed `cd` left SCRIPT_DIR empty (no `set -e`),
# so PYTHONPATH would be set to a bogus path and the run would fail much later.
if [[ -z "${SCRIPT_DIR}" ]]; then
  echo "Unable to resolve the script directory." >&2
  exit 1
fi
ROOT_DIR="$(dirname "${SCRIPT_DIR}")"

# Put the repository root on PYTHONPATH so the training module is importable.
export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"
# Default log level; callers may override it via the environment.
export LOGLEVEL="${LOGLEVEL:-WARNING}"
|
|
|
# Defaults; each may be overridden on the command line.
MODEL_NAME_OR_PATH="cerebras/btlm-3b-8k-base"
OUTPUT_DIR="${ROOT_DIR}/output/sft"
ZERO_STAGE=2

# Abort with a clear message when an option that takes a value is given none.
# Fix: previously a trailing `--output_dir` (etc.) silently assigned an empty
# string and the extra `shift` failed quietly.
require_value() {
  if [[ "$#" -lt 2 ]]; then
    echo "Missing value for option '$1'" >&2
    exit 1
  fi
}

while [[ "$#" -gt 0 ]]; do
  arg="$1"
  shift
  case "${arg}" in
    --model_name_or_path)
      require_value "${arg}" "$@"
      MODEL_NAME_OR_PATH="$1"
      shift
      ;;
    --model_name_or_path=*)
      MODEL_NAME_OR_PATH="${arg#*=}"
      ;;
    --output_dir)
      require_value "${arg}" "$@"
      OUTPUT_DIR="$1"
      shift
      ;;
    --output_dir=*)
      OUTPUT_DIR="${arg#*=}"
      ;;
    --zero_stage)
      require_value "${arg}" "$@"
      ZERO_STAGE="$1"
      shift
      ;;
    --zero_stage=*)
      ZERO_STAGE="${arg#*=}"
      ;;
    *)
      echo "Unknown parameter passed: '${arg}'" >&2
      exit 1
      ;;
  esac
done
|
|
|
# Materialize the output directory and normalize it to an absolute path.
mkdir -p "${OUTPUT_DIR}" || {
  echo "Failed to create output directory '${OUTPUT_DIR}'" >&2
  exit 1
}
OUTPUT_DIR="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"
# Fix: if the cd/pwd normalization failed, OUTPUT_DIR became empty and the
# writes below targeted "/.gitignore" and "/script.sh" (no `set -e`); abort.
if [[ -z "${OUTPUT_DIR}" ]]; then
  echo "Failed to resolve the output directory." >&2
  exit 1
fi
# Keep training artifacts out of version control.
if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
  echo '*' >"${OUTPUT_DIR}/.gitignore"
fi

# Snapshot this launcher next to its outputs for reproducibility.
cp -f "$0" "${OUTPUT_DIR}/script.sh"
|
|
|
# Without W&B credentials, log offline so the run does not stall on login.
[[ -n "${WANDB_API_KEY:-}" ]] || export WANDB_MODE="offline"
|
|
|
# Pick a random TCP port in [MASTER_PORT_START, MASTER_PORT_END] that no
# local socket currently occupies: comm -23 drops in-use ports from the
# candidate list, then one survivor is chosen uniformly at random.
MASTER_PORT_START=10000
MASTER_PORT_END=65535
MASTER_PORT="$(
  comm -23 \
    <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
    <(ss -Htan | awk '{ n = split($4, parts, ":"); print parts[n] }' | sort -u) |
    shuf -n 1
)"
|
|
|
exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2) |
|
|
|
# Launch supervised fine-tuning via the DeepSpeed launcher on one node / 8 GPUs,
# running the `safe_rlhf.finetune` module with the settings parsed above
# (model path, output dir, ZeRO stage) plus fixed hyperparameters.
# NOTE(review): the default model is an 8k-context checkpoint (8192 tokens);
# `--max_length 8092` looks like a typo for 8192 — confirm the intended length.
# NOTE(review): `--trust_remote_code True` executes code shipped with the model
# repository; make sure the checkpoint source is trusted.
deepspeed --num_nodes=1 --num_gpus=8 \
	--master_port "${MASTER_PORT}" \
	--module safe_rlhf.finetune \
	--train_datasets bt \
	--model_name_or_path "${MODEL_NAME_OR_PATH}" \
	--max_length 8092 \
	--trust_remote_code True \
	--epochs 16 \
	--per_device_train_batch_size 8 \
	--per_device_eval_batch_size 2 \
	--gradient_accumulation_steps 1 \
	--gradient_checkpointing \
	--learning_rate 4.7e-6 \
	--lr_scheduler_type cosine \
	--num_warmup_steps 20 \
	--weight_decay 0.0 \
	--seed 42 \
	--output_dir "${OUTPUT_DIR}" \
	--log_type wandb \
	--log_project BT-Training \
	--zero_stage "${ZERO_STAGE}" \
	--bf16 True \
	--tf32 True
|
|