Commit b839dd6
Parent(s): cae4858
add config for training
main.py CHANGED

@@ -94,15 +94,15 @@ if __name__ == "__main__":
     os.makedirs(cache_processing_dataset_folder)
     num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
     num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
-    num_epochs =
+    num_epochs = 5000
 
     training_args = TrainingArguments(
         output_dir=checkpoint_path,
         # fp16=True,
         group_by_length=True,
-        per_device_train_batch_size=
-        per_device_eval_batch_size=
-        gradient_accumulation_steps=
+        per_device_train_batch_size=16,
+        per_device_eval_batch_size=16,
+        gradient_accumulation_steps=8,
         num_train_epochs=1,  # each epoch per shard data
         logging_steps=1,
         learning_rate=1e-4,
@@ -146,7 +146,7 @@ if __name__ == "__main__":
         cache_file_name=os.path.join(cache_processing_dataset_folder,
                                      'cache-train-shard-{}.arrow'.format(
                                          train_dataset_shard_idx))
-    ).shard(1000, 0)  # Remove shard split when train
+    )  # .shard(1000, 0)  # Remove shard split when train
     # load test shard subset
     test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                          'shard_{}'.format(test_dataset_shard_idx)),