nguyenvulebinh
committed on
Commit
•
cb2b82e
1
Parent(s):
cbf9056
filter wav 10s and new pretrained model
Browse files
- main.py +27 -13
- model-bin/pretrained/base/pytorch_model.bin +1 -1
main.py
CHANGED
@@ -45,6 +45,7 @@ def load_pretrained_model(checkpoint_path=None):
|
|
45 |
)
|
46 |
# model.freeze_feature_extractor()
|
47 |
|
|
|
48 |
model_total_params = sum(p.numel() for p in model.parameters())
|
49 |
model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
50 |
print(model)
|
@@ -68,15 +69,19 @@ def prepare_dataset(batch, processor):
|
|
68 |
return batch
|
69 |
|
70 |
|
71 |
-
def load_prepared_dataset(path, processor,
|
72 |
dataset = load_from_disk(path)
|
|
|
|
|
|
|
|
|
73 |
processed_dataset = dataset.map(prepare_dataset,
|
74 |
remove_columns=dataset.column_names,
|
75 |
batch_size=32,
|
76 |
-
num_proc=
|
77 |
batched=True,
|
78 |
fn_kwargs={"processor": processor},
|
79 |
-
cache_file_name=
|
80 |
return processed_dataset
|
81 |
|
82 |
|
@@ -105,9 +110,9 @@ if __name__ == "__main__":
|
|
105 |
output_dir=checkpoint_path,
|
106 |
fp16=True,
|
107 |
group_by_length=True,
|
108 |
-
per_device_train_batch_size=
|
109 |
-
per_device_eval_batch_size=
|
110 |
-
gradient_accumulation_steps=
|
111 |
num_train_epochs=num_epochs, # each epoch per shard data
|
112 |
logging_steps=1,
|
113 |
learning_rate=1e-4,
|
@@ -150,17 +155,26 @@ if __name__ == "__main__":
|
|
150 |
train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
|
151 |
'shard_{}'.format(train_dataset_shard_idx)),
|
152 |
w2v_ctc_processor,
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
|
|
|
|
|
|
|
|
|
|
157 |
# load test shard subset
|
158 |
test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
|
159 |
'shard_{}'.format(test_dataset_shard_idx)),
|
160 |
w2v_ctc_processor,
|
161 |
-
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
164 |
)
|
165 |
test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
|
166 |
# Init trainer
|
|
|
45 |
)
|
46 |
# model.freeze_feature_extractor()
|
47 |
|
48 |
+
# model = Wav2Vec2ForCTC(model.config)
|
49 |
model_total_params = sum(p.numel() for p in model.parameters())
|
50 |
model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
51 |
print(model)
|
|
|
69 |
return batch
|
70 |
|
71 |
|
72 |
+
def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=8):
|
73 |
dataset = load_from_disk(path)
|
74 |
+
dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
|
75 |
+
batch_size=32,
|
76 |
+
num_proc=num_proc,
|
77 |
+
cache_file_name=cache_file_filter_name)
|
78 |
processed_dataset = dataset.map(prepare_dataset,
|
79 |
remove_columns=dataset.column_names,
|
80 |
batch_size=32,
|
81 |
+
num_proc=num_proc,
|
82 |
batched=True,
|
83 |
fn_kwargs={"processor": processor},
|
84 |
+
cache_file_name=cache_file_map_name)
|
85 |
return processed_dataset
|
86 |
|
87 |
|
|
|
110 |
output_dir=checkpoint_path,
|
111 |
fp16=True,
|
112 |
group_by_length=True,
|
113 |
+
per_device_train_batch_size=32,
|
114 |
+
per_device_eval_batch_size=32,
|
115 |
+
gradient_accumulation_steps=2,
|
116 |
num_train_epochs=num_epochs, # each epoch per shard data
|
117 |
logging_steps=1,
|
118 |
learning_rate=1e-4,
|
|
|
155 |
train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
|
156 |
'shard_{}'.format(train_dataset_shard_idx)),
|
157 |
w2v_ctc_processor,
|
158 |
+
cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
|
159 |
+
'train',
|
160 |
+
'cache-train-filter-shard-{}.arrow'.format(
|
161 |
+
train_dataset_shard_idx)),
|
162 |
+
cache_file_map_name=os.path.join(cache_processing_dataset_folder,
|
163 |
+
'train',
|
164 |
+
'cache-train-map-shard-{}.arrow'.format(
|
165 |
+
train_dataset_shard_idx)),
|
166 |
+
) #.shard(1000, 0) # Remove shard split when train
|
167 |
# load test shard subset
|
168 |
test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
|
169 |
'shard_{}'.format(test_dataset_shard_idx)),
|
170 |
w2v_ctc_processor,
|
171 |
+
cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
|
172 |
+
'test',
|
173 |
+
'cache-test-filter-shard-{}.arrow'.format(
|
174 |
+
test_dataset_shard_idx)),
|
175 |
+
cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'test',
|
176 |
+
'cache-test-map-shard-{}.arrow'.format(
|
177 |
+
test_dataset_shard_idx))
|
178 |
)
|
179 |
test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
|
180 |
# Init trainer
|
model-bin/pretrained/base/pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 380261837
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b8fc5e67c00d407cd160a238034677db5670cbc77fe766c53d1042478509574d
|
3 |
size 380261837
|