from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, \
TrainerCallback
from datasets import load_from_disk
from data_handler import DataCollatorCTCWithPadding
from transformers import TrainingArguments
from transformers import Trainer, logging
from metric_utils import compute_metrics_fn
from transformers.trainer_utils import get_last_checkpoint
import json
import os
import glob
from callbacks import BreakEachEpoch
logging.set_verbosity_info()


def load_pretrained_model(checkpoint_path=None):
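    """Build a Wav2Vec2ForCTC model and its Wav2Vec2Processor.

    With checkpoint_path=None, loads the base pre-trained encoder from
    ./model-bin/pretrained/base, builds the tokenizer from the fine-tuning
    vocab, and freezes the feature extractor; otherwise every component is
    restored from the given checkpoint directory.
    """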
if checkpoint_path is None:
pre_trained_path = './model-bin/pretrained/base'
tokenizer = Wav2Vec2CTCTokenizer("./model-bin/finetune/vocab.json",
unk_token="<unk>",
pad_token="<pad>",
word_delimiter_token="|")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pre_trained_path)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = Wav2Vec2ForCTC.from_pretrained(
pre_trained_path,
gradient_checkpointing=True,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
)
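        # NOTE: newer transformers releases deprecate passing gradient_checkpointing
        # to from_pretrained; model.gradient_checkpointing_enable() is the replacement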
model.freeze_feature_extractor()
else:
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint_path)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint_path)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = Wav2Vec2ForCTC.from_pretrained(
checkpoint_path,
gradient_checkpointing=True,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
)
# model.freeze_feature_extractor()
# model = Wav2Vec2ForCTC(model.config)
    model_total_params = sum(p.numel() for p in model.parameters())
    model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print(f"model_total_params: {model_total_params}\n"
          f"model_total_params_trainable: {model_total_params_trainable}")
    return model, processor


def prepare_dataset(batch, processor):
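    """Featurize one batched example: raw speech -> input_values (plus length),
    target_text -> CTC label ids via the tokenizer."""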
    # check that all inputs in the batch share the same sampling rate
assert (
len(set(batch["sampling_rate"])) == 1
), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
batch["length"] = [len(item) for item in batch["input_values"]]
with processor.as_target_processor():
batch["labels"] = processor(batch["target_text"]).input_ids
    return batch


def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=8):
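    """Load a dataset shard from disk, drop utterances of 160k samples or more
    (~10 s of audio at 16 kHz), and featurize it with prepare_dataset. Both the
    filter and map results are cached in the given .arrow files so re-processing
    the same shard is cheap."""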
dataset = load_from_disk(path)
dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
batch_size=32,
num_proc=num_proc,
cache_file_name=cache_file_filter_name)
processed_dataset = dataset.map(prepare_dataset,
remove_columns=dataset.column_names,
batch_size=32,
num_proc=num_proc,
batched=True,
fn_kwargs={"processor": processor},
cache_file_name=cache_file_map_name)
return processed_dataset


if __name__ == "__main__":
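    # Training strategy: the corpus is pre-split into shard_* folders on disk, and
    # each Trainer "epoch" consumes exactly one train shard (enforced by the
    # BreakEachEpoch callback), cycling through all shards over num_epochs.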
checkpoint_path = "./model-bin/finetune/base/"
# train_dataset_root_folder = './data-bin/train_dataset'
# test_dataset_root_folder = './data-bin/test_dataset'
train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'
cache_processing_dataset_folder = './data-bin/cache/'
    # Create cache folders for the processed train/test shards if missing
    os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'), exist_ok=True)
    os.makedirs(os.path.join(cache_processing_dataset_folder, 'test'), exist_ok=True)
num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
num_epochs = 5000
training_args = TrainingArguments(
output_dir=checkpoint_path,
fp16=True,
group_by_length=True,
per_device_train_batch_size=32,
per_device_eval_batch_size=32,
gradient_accumulation_steps=2,
        num_train_epochs=num_epochs,  # one "epoch" = one pass over a single data shard
logging_steps=1,
learning_rate=1e-4,
weight_decay=0.005,
warmup_steps=1000,
save_total_limit=2,
ignore_data_skip=True,
logging_dir=os.path.join(checkpoint_path, 'log'),
metric_for_best_model='wer',
save_strategy="epoch",
evaluation_strategy="epoch",
greater_is_better=False,
# save_steps=5,
# eval_steps=5,
)
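    # Effective train batch size per device: 32 * 2 gradient-accumulation steps = 64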
trainer = None
last_checkpoint_path = None
last_epoch_idx = 0
if os.path.exists(checkpoint_path):
last_checkpoint_path = get_last_checkpoint(checkpoint_path)
if last_checkpoint_path is not None:
with open(os.path.join(last_checkpoint_path, "trainer_state.json"), 'r', encoding='utf-8') as file:
trainer_state = json.load(file)
last_epoch_idx = int(trainer_state['epoch'])
w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
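    # The model starts from the base pre-trained weights here; when a checkpoint
    # exists, trainer.train(resume_from_checkpoint=True) below restores model and
    # optimizer state on top of this initialization.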
data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)
for epoch_idx in range(last_epoch_idx, num_epochs):
# loop over training shards
train_dataset_shard_idx = epoch_idx % num_train_shards
        # Pick the test shard corresponding to this train shard; clamp so that
        # rounding can never step past the last test shard index
        test_dataset_shard_idx = min(round(train_dataset_shard_idx / (num_train_shards / num_test_shards)),
                                     num_test_shards - 1)
        num_test_sub_shard = 8  # each test shard is split into this many sub-shards
        idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # cycle over the sub-shards
# load train shard
        train_dataset = load_prepared_dataset(
            os.path.join(train_dataset_root_folder, 'shard_{}'.format(train_dataset_shard_idx)),
            w2v_ctc_processor,
            cache_file_filter_name=os.path.join(cache_processing_dataset_folder, 'train',
                                                'cache-train-filter-shard-{}.arrow'.format(train_dataset_shard_idx)),
            cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'train',
                                             'cache-train-map-shard-{}.arrow'.format(train_dataset_shard_idx)),
        )  # append .shard(1000, 0) here only for quick debugging runs
# load test shard subset
        test_dataset = load_prepared_dataset(
            os.path.join(test_dataset_root_folder, 'shard_{}'.format(test_dataset_shard_idx)),
            w2v_ctc_processor,
            cache_file_filter_name=os.path.join(cache_processing_dataset_folder, 'test',
                                                'cache-test-filter-shard-{}.arrow'.format(test_dataset_shard_idx)),
            cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'test',
                                             'cache-test-map-shard-{}.arrow'.format(test_dataset_shard_idx)),
        )
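        # Keep evaluation cheap: use only one sub-shard of the test shard per epoch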
test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
# Init trainer
if trainer is None:
trainer = Trainer(
model=w2v_ctc_model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics_fn(w2v_ctc_processor),
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=w2v_ctc_processor.feature_extractor,
                callbacks=[BreakEachEpoch()]  # end each "epoch" manually, since an epoch covers one shard
)
else:
trainer.train_dataset = train_dataset
trainer.eval_dataset = test_dataset
logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
logging.get_logger().info(
'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard))
        if last_checkpoint_path is not None:
            # resume from the last saved checkpoint
            trainer.train(resume_from_checkpoint=True)
        else:
            # first run: start from the pre-trained wav2vec2 weights
            trainer.train()
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)
        # Remove this shard's processed-dataset cache files to free disk space
test_dataset.cleanup_cache_files()
train_dataset.cleanup_cache_files()