|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, \ |
|
TrainerCallback |
|
from datasets import load_from_disk |
|
from data_handler import DataCollatorCTCWithPadding |
|
from transformers import TrainingArguments |
|
from transformers import Trainer, logging |
|
from metric_utils import compute_metrics_fn |
|
from transformers.trainer_utils import get_last_checkpoint |
|
import json |
|
import os, glob |
|
from callbacks import BreakEachEpoch |
|
|
|
# Raise the transformers library logger to INFO so that the info-level
# messages logged by this script (via logging.get_logger()) are emitted.
logging.set_verbosity(logging.INFO)
|
|
|
|
|
def load_pretrained_model(checkpoint_path=None,
                          pretrained_path='./model-bin/pretrained/base',
                          vocab_path='./model-bin/finetune/vocab.json'):
    """Build the wav2vec2 CTC model together with its processor.

    Args:
        checkpoint_path: directory of a previously saved fine-tuned model.
            When ``None``, the model and feature extractor are loaded from
            ``pretrained_path`` and a fresh CTC tokenizer is built from
            ``vocab_path``. Otherwise the model, tokenizer and feature
            extractor are all loaded from ``checkpoint_path``.
        pretrained_path: directory of the pretrained (not yet fine-tuned)
            base model; only used when ``checkpoint_path`` is ``None``.
        vocab_path: vocabulary JSON for the CTC tokenizer; only used when
            ``checkpoint_path`` is ``None``.

    Returns:
        ``(model, processor)``, where ``model`` is a ``Wav2Vec2ForCTC`` with
        its feature extractor frozen, and ``processor`` is a
        ``Wav2Vec2Processor`` wrapping the feature extractor and tokenizer.
    """
    if checkpoint_path is None:
        model_source = pretrained_path
        tokenizer = Wav2Vec2CTCTokenizer(vocab_path,
                                         unk_token="<unk>",
                                         pad_token="<pad>",
                                         word_delimiter_token="|")
    else:
        model_source = checkpoint_path
        # NOTE(review): requires tokenizer files (e.g. vocab.json) inside the
        # checkpoint directory — confirm they are actually saved there.
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint_path)

    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_source)
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

    model = Wav2Vec2ForCTC.from_pretrained(
        model_source,
        gradient_checkpointing=True,
        ctc_loss_reduction="mean",
        # The CTC blank/pad id must match the tokenizer's pad token.
        pad_token_id=processor.tokenizer.pad_token_id,
    )
    # Freeze the feature extractor in both branches. Previously only the
    # pretrained branch froze it, so loading from a checkpoint silently made
    # those parameters trainable.
    model.freeze_feature_extractor()

    model_total_params = sum(p.numel() for p in model.parameters())
    model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params,
                                                                             model_total_params_trainable))
    return model, processor
|
|
|
|
|
def prepare_dataset(batch, processor):
    """Convert a batch of raw audio/text examples into CTC model inputs.

    Designed for ``Dataset.map(..., batched=True)``: each value in ``batch``
    is a list with one entry per example.

    Args:
        batch: mapping with ``"speech"`` (raw waveforms), ``"sampling_rate"``
            (one rate per example) and ``"target_text"`` (transcripts).
        processor: a ``Wav2Vec2Processor`` (or compatible object) used both
            as feature extractor and, inside ``as_target_processor()``, as
            tokenizer.

    Returns:
        The same ``batch`` mapping with ``"input_values"``, ``"length"``
        (number of input values per example) and ``"labels"`` (token ids)
        added.

    Raises:
        ValueError: if the batch does not contain exactly one distinct
            sampling rate (mixed rates, or an empty batch).
    """
    sampling_rates = set(batch["sampling_rate"])
    # Explicit check rather than ``assert`` so validation still runs under ``python -O``.
    if len(sampling_rates) != 1:
        raise ValueError(
            "Make sure all inputs have the same sampling rate of {}; got {}.".format(
                processor.feature_extractor.sampling_rate, sorted(sampling_rates)))

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
    # Per-example lengths; presumably consumed by ``group_by_length`` in the Trainer.
    batch["length"] = [len(item) for item in batch["input_values"]]

    # Inside this context the processor tokenizes text instead of extracting audio features.
    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch
|
|
|
|
|
def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=8,
                          max_input_length=160000):
    """Load a saved dataset shard, drop overly long clips, and featurize it.

    Args:
        path: directory previously written with ``save_to_disk``.
        processor: processor forwarded to ``prepare_dataset``.
        cache_file_filter_name: Arrow cache file for the filter step.
        cache_file_map_name: Arrow cache file for the map step.
        num_proc: number of worker processes for filter and map.
        max_input_length: examples whose ``speech`` array has this many
            samples or more are dropped (the default keeps the original
            hard-coded threshold).

    Returns:
        The processed dataset. All original columns are removed, so only the
        columns produced by ``prepare_dataset`` remain.

    NOTE(review): ``datasets`` may reuse an existing file at
    ``cache_file_*_name`` instead of recomputing it. If you change
    ``max_input_length`` or the processor, use fresh cache file names or
    clear the old files — confirm against the installed ``datasets`` version.
    """
    dataset = load_from_disk(path)
    dataset = dataset.filter(lambda example: len(example['speech']) < max_input_length,
                             batch_size=32,
                             num_proc=num_proc,
                             cache_file_name=cache_file_filter_name)
    processed_dataset = dataset.map(prepare_dataset,
                                    remove_columns=dataset.column_names,
                                    batch_size=32,
                                    num_proc=num_proc,
                                    batched=True,
                                    fn_kwargs={"processor": processor},
                                    cache_file_name=cache_file_map_name)
    return processed_dataset
|
|
|
|
|
|
|
|
|
|
|
def pick_test_shard(train_shard_idx, num_train_shards, num_test_shards):
    """Map a train shard index to a valid test shard index.

    Train shards are spread evenly over the test shards with floor division,
    so the result always lies in ``[0, num_test_shards)``. The earlier
    ``round(...)`` mapping could return ``num_test_shards``, which names a
    shard that does not exist.

    Args:
        train_shard_idx: current train shard, in ``[0, num_train_shards)``.
        num_train_shards: number of train shards (must be > 0).
        num_test_shards: number of test shards (must be > 0).

    Returns:
        The test shard index to pair with ``train_shard_idx``.
    """
    return train_shard_idx * num_test_shards // num_train_shards


if __name__ == "__main__":
    # Training driver. Each outer iteration trains on one train shard and
    # evaluates on a sub-shard of a matching test shard. The Trainer is
    # reused across iterations and resumes from the latest checkpoint.
    checkpoint_path = "./model-bin/finetune/base/"

    train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
    test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'

    cache_processing_dataset_folder = './data-bin/cache/'
    # Create each cache sub-folder independently. Previously 'test' was only
    # created when 'train' was missing.
    for split_name in ('train', 'test'):
        os.makedirs(os.path.join(cache_processing_dataset_folder, split_name), exist_ok=True)

    num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
    num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
    if num_train_shards == 0 or num_test_shards == 0:
        raise FileNotFoundError(
            "No 'shard_*' folders found (train: {} in {}, test: {} in {})".format(
                num_train_shards, train_dataset_root_folder, num_test_shards, test_dataset_root_folder))

    num_epochs = 5000
    # Each iteration evaluates on 1/num_test_sub_shard of the selected test shard.
    num_test_sub_shard = 8

    training_args = TrainingArguments(
        output_dir=checkpoint_path,
        fp16=True,
        group_by_length=True,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        gradient_accumulation_steps=2,
        num_train_epochs=num_epochs,
        logging_steps=1,
        learning_rate=1e-4,
        weight_decay=0.005,
        warmup_steps=1000,
        save_total_limit=2,
        ignore_data_skip=True,
        logging_dir=os.path.join(checkpoint_path, 'log'),
        metric_for_best_model='wer',
        save_strategy="epoch",
        evaluation_strategy="epoch",
        greater_is_better=False,
    )

    trainer = None
    logger = logging.get_logger()

    # Work out which epoch to continue from, based on the latest checkpoint's trainer state.
    last_checkpoint_path = None
    last_epoch_idx = 0
    if os.path.exists(checkpoint_path):
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)
    if last_checkpoint_path is not None:
        with open(os.path.join(last_checkpoint_path, "trainer_state.json"), 'r', encoding='utf-8') as file:
            trainer_state = json.load(file)
        last_epoch_idx = int(trainer_state['epoch'])

    # Weights always start from the pretrained base. When a checkpoint exists,
    # trainer.train(resume_from_checkpoint=True) below restores the state.
    w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
    data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)

    for epoch_idx in range(last_epoch_idx, num_epochs):
        train_dataset_shard_idx = epoch_idx % num_train_shards
        test_dataset_shard_idx = pick_test_shard(train_dataset_shard_idx, num_train_shards, num_test_shards)
        idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard

        train_dataset = load_prepared_dataset(
            os.path.join(train_dataset_root_folder, 'shard_{}'.format(train_dataset_shard_idx)),
            w2v_ctc_processor,
            cache_file_filter_name=os.path.join(cache_processing_dataset_folder, 'train',
                                                'cache-train-filter-shard-{}.arrow'.format(
                                                    train_dataset_shard_idx)),
            cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'train',
                                             'cache-train-map-shard-{}.arrow'.format(
                                                 train_dataset_shard_idx)),
        )

        test_dataset = load_prepared_dataset(
            os.path.join(test_dataset_root_folder, 'shard_{}'.format(test_dataset_shard_idx)),
            w2v_ctc_processor,
            cache_file_filter_name=os.path.join(cache_processing_dataset_folder, 'test',
                                                'cache-test-filter-shard-{}.arrow'.format(
                                                    test_dataset_shard_idx)),
            cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'test',
                                             'cache-test-map-shard-{}.arrow'.format(
                                                 test_dataset_shard_idx)),
        )
        test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)

        if trainer is None:
            trainer = Trainer(
                model=w2v_ctc_model,
                data_collator=data_collator,
                args=training_args,
                compute_metrics=compute_metrics_fn(w2v_ctc_processor),
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                # Only the feature extractor is passed (and thus saved with checkpoints).
                tokenizer=w2v_ctc_processor.feature_extractor,
                # NOTE(review): BreakEachEpoch presumably stops train() after one
                # epoch so this loop can swap shards — confirm in callbacks.py.
                callbacks=[BreakEachEpoch()]
            )
        else:
            trainer.train_dataset = train_dataset
            trainer.eval_dataset = test_dataset

        logger.info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
        logger.info(
            'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard))

        if last_checkpoint_path is not None:
            trainer.train(resume_from_checkpoint=True)
        else:
            trainer.train()
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)

        test_dataset.cleanup_cache_files()
        train_dataset.cleanup_cache_files()
|
|