Update scripts to work around collator ValueError. Update weights
Browse files- config.json +2 -0
- flax_model.msgpack +1 -1
- opt_state.msgpack +3 -0
- pytorch_model.bin +1 -1
- run_t5.sh +37 -22
- run_t5_mlm_flax_custom_dataset.py +5 -0
- runs/{Jul11_17-06-36_t1v-n-0e7426e8-w-0/events.out.tfevents.1626023202.t1v-n-0e7426e8-w-0.178001.3.v2 → Jul12_06-43-08_t1v-n-0e7426e8-w-0/events.out.tfevents.1626072193.t1v-n-0e7426e8-w-0.238699.3.v2} +2 -2
- training_state.json +1 -0
config.json
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
{
|
|
|
2 |
"architectures": [
|
3 |
"T5ForConditionalGeneration"
|
4 |
],
|
@@ -50,6 +51,7 @@
|
|
50 |
"prefix": "translate English to Romanian: "
|
51 |
}
|
52 |
},
|
|
|
53 |
"transformers_version": "4.9.0.dev0",
|
54 |
"use_cache": true,
|
55 |
"vocab_size": 32103
|
|
|
1 |
{
|
2 |
+
"_name_or_path": ".",
|
3 |
"architectures": [
|
4 |
"T5ForConditionalGeneration"
|
5 |
],
|
|
|
51 |
"prefix": "translate English to Romanian: "
|
52 |
}
|
53 |
},
|
54 |
+
"torch_dtype": "float32",
|
55 |
"transformers_version": "4.9.0.dev0",
|
56 |
"use_cache": true,
|
57 |
"vocab_size": 32103
|
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 891548548
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c8d5a4eb1275b4c679b148f38edb974772997a3925809f39095204009f83502
|
3 |
size 891548548
|
opt_state.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:97c0ff372805930fa4d7e81ae09094b7daf3cc2c1ba06224fc522a8e672af91a
|
3 |
+
size 1985609
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 891650495
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:782edc5c7aa8aa66320a3417abff572760287ee6a7759f1867486d2217563650
|
3 |
size 891650495
|
run_t5.sh
CHANGED
@@ -7,28 +7,42 @@ mkdir -p "${MODEL_DIR}/runs"
|
|
7 |
# T5 paper lr 0.01 with batch size 128
|
8 |
# We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
|
34 |
#git add pytorch_model.bin
|
@@ -37,3 +51,4 @@ mkdir -p "${MODEL_DIR}/runs"
|
|
37 |
|
38 |
# --gradient_accumulation_steps="2" \
|
39 |
|
|
|
|
7 |
# T5 paper lr 0.01 with batch size 128
|
8 |
# We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2
|
9 |
|
10 |
+
while true; do
|
11 |
+
|
12 |
+
# Set the seed to random before each run, so data shuffling per epoch is different each run.
|
13 |
+
# This kills reproducibility, but is required as long as a ValueError can be raised during training.
|
14 |
+
SEED=$RANDOM
|
15 |
+
|
16 |
+
./run_t5_mlm_flax_custom_dataset.py \
|
17 |
+
--output_dir="${MODEL_DIR}" \
|
18 |
+
--model_type="t5" \
|
19 |
+
--config_name="flax-community/${MODEL}" \
|
20 |
+
--tokenizer_name="${MODEL_DIR}" \
|
21 |
+
--seed="${SEED}" \
|
22 |
+
--preprocessing_num_workers="96" \
|
23 |
+
--do_train --do_eval \
|
24 |
+
--adafactor \
|
25 |
+
--max_seq_length="512" \
|
26 |
+
--per_device_train_batch_size="32" \
|
27 |
+
--per_device_eval_batch_size="32" \
|
28 |
+
--learning_rate="5e-3" \
|
29 |
+
--dtype="bfloat16" \
|
30 |
+
--overwrite_output_dir \
|
31 |
+
--num_train_epochs="3" \
|
32 |
+
--logging_steps="50" \
|
33 |
+
--save_steps="501" \
|
34 |
+
--eval_steps="10000000" \
|
35 |
+
--resume_from_checkpoint="${MODEL_DIR}" \
|
36 |
+
--warmup_steps="3413"
|
37 |
+
|
38 |
+
# \
|
39 |
+
# --push_to_hub
|
40 |
+
|
41 |
+
echo "RESTARTING"
|
42 |
+
sleep 20
|
43 |
+
done
|
44 |
+
#
|
45 |
+
# \
|
46 |
|
47 |
|
48 |
#git add pytorch_model.bin
|
|
|
51 |
|
52 |
# --gradient_accumulation_steps="2" \
|
53 |
|
54 |
+
# --resume_from_checkpoint="${MODEL_DIR}/ckpt-18000" \
|
run_t5_mlm_flax_custom_dataset.py
CHANGED
@@ -432,6 +432,11 @@ def save_checkpoint(model, save_dir, state, with_opt: bool = True):
|
|
432 |
push_to_hub=training_args.push_to_hub,
|
433 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
434 |
)
|
|
|
|
|
|
|
|
|
|
|
435 |
logger.info("checkpoint saved")
|
436 |
|
437 |
|
|
|
432 |
push_to_hub=training_args.push_to_hub,
|
433 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
434 |
)
|
435 |
+
if with_opt:
|
436 |
+
with open(os.path.join(training_args.output_dir, "opt_state.msgpack"), "wb") as f:
|
437 |
+
f.write(to_bytes(state.opt_state))
|
438 |
+
with open(os.path.join(training_args.output_dir, "training_state.json"), "w") as f:
|
439 |
+
json.dump({"step": state.step.item()}, f)
|
440 |
logger.info("checkpoint saved")
|
441 |
|
442 |
|
runs/{Jul11_17-06-36_t1v-n-0e7426e8-w-0/events.out.tfevents.1626023202.t1v-n-0e7426e8-w-0.178001.3.v2 → Jul12_06-43-08_t1v-n-0e7426e8-w-0/events.out.tfevents.1626072193.t1v-n-0e7426e8-w-0.238699.3.v2}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f5f6fcc83f8cf7fac87cc276fa00a02c9ce4e252c6bb69a3988452bed73f67e
|
3 |
+
size 200238
|
training_state.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"step": 15004}
|