import glob
import json
import os

from datasets import load_from_disk
from transformers import (
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    logging,
)
from transformers.trainer_utils import get_last_checkpoint

from callbacks import BreakEachEpoch
from data_handler import DataCollatorCTCWithPadding
from metric_utils import compute_metrics_fn

logging.set_verbosity_info()


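# Build the (model, processor) pair.
# With no checkpoint_path: load the base pre-trained wav2vec2 encoder, build a fresh CTC
# tokenizer from ./model-bin/finetune/vocab.json and freeze the feature extractor so only
# the transformer layers and CTC head are fine-tuned.
# With a checkpoint_path: restore tokenizer, feature extractor and model weights from that
# checkpoint instead.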
def load_pretrained_model(checkpoint_path=None):
    if checkpoint_path is None:
        pre_trained_path = './model-bin/pretrained/base'
        tokenizer = Wav2Vec2CTCTokenizer("./model-bin/finetune/vocab.json",
                                         unk_token="<unk>",
                                         pad_token="<pad>",
                                         word_delimiter_token="|")

        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pre_trained_path)
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        model = Wav2Vec2ForCTC.from_pretrained(
            pre_trained_path,
            gradient_checkpointing=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
        )
        model.freeze_feature_extractor()
    else:
        tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(checkpoint_path)

        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint_path)
        processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

        model = Wav2Vec2ForCTC.from_pretrained(
            checkpoint_path,
            gradient_checkpointing=True,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
        )
        # model.freeze_feature_extractor()

    # model = Wav2Vec2ForCTC(model.config)
    model_total_params = sum(p.numel() for p in model.parameters())
    model_total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(model)
    print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params,
                                                                            model_total_params_trainable))
    return model, processor


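# Batched dataset.map function: turns the raw "speech" arrays into model "input_values",
# records each example's length (used by the trainer's group_by_length sampler) and
# encodes "target_text" into CTC label ids.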
def prepare_dataset(batch, processor):
    # check that all files have the sampling rate the feature extractor expects
    assert (
            len(set(batch["sampling_rate"])) == 1
            and batch["sampling_rate"][0] == processor.feature_extractor.sampling_rate
    ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

    batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values

    # per-example lengths, used by the trainer's group_by_length sampler
    batch["length"] = [len(item) for item in batch["input_values"]]

    with processor.as_target_processor():
        batch["labels"] = processor(batch["target_text"]).input_ids
    return batch


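# Load one dataset shard from disk, drop utterances of 160000 samples or more (10 s at the
# 16 kHz rate wav2vec2 expects) and run prepare_dataset over it, caching the filter and map
# results in the given .arrow files.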
def load_prepared_dataset(path, processor, cache_file_filter_name, cache_file_map_name, num_proc=8):
    dataset = load_from_disk(path)
    dataset = dataset.filter(lambda example: len(example['speech']) < 160000,
                             batch_size=32,
                             num_proc=num_proc,
                             cache_file_name=cache_file_filter_name)
    processed_dataset = dataset.map(prepare_dataset,
                                    remove_columns=dataset.column_names,
                                    batch_size=32,
                                    num_proc=num_proc,
                                    batched=True,
                                    fn_kwargs={"processor": processor},
                                    cache_file_name=cache_file_map_name)
    return processed_dataset



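# Training flow: the corpus is pre-sharded on disk. Each Trainer "epoch" trains on a single
# train shard and evaluates on a sub-shard of the matching test shard; the BreakEachEpoch
# callback then stops the Trainer so the next shard can be swapped in. Progress is resumed
# from the last checkpoint's trainer_state.json.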
if __name__ == "__main__":

    checkpoint_path = "./model-bin/finetune/base/"

    # train_dataset_root_folder = './data-bin/train_dataset'
    # test_dataset_root_folder = './data-bin/test_dataset'

    train_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/train_dataset'
    test_dataset_root_folder = '/content/drive/MyDrive/audio_dataset/test_dataset'

    cache_processing_dataset_folder = './data-bin/cache/'
    os.makedirs(os.path.join(cache_processing_dataset_folder, 'train'), exist_ok=True)
    os.makedirs(os.path.join(cache_processing_dataset_folder, 'test'), exist_ok=True)
    num_train_shards = len(glob.glob(os.path.join(train_dataset_root_folder, 'shard_*')))
    num_test_shards = len(glob.glob(os.path.join(test_dataset_root_folder, 'shard_*')))
    num_epochs = 5000

    training_args = TrainingArguments(
        output_dir=checkpoint_path,
        fp16=True,
        group_by_length=True,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        gradient_accumulation_steps=2,
        num_train_epochs=num_epochs,  # each "epoch" trains on a single data shard
        logging_steps=1,
        learning_rate=1e-4,
        weight_decay=0.005,
        warmup_steps=1000,
        save_total_limit=2,
        ignore_data_skip=True,
        logging_dir=os.path.join(checkpoint_path, 'log'),
        metric_for_best_model='wer',
        save_strategy="epoch",
        evaluation_strategy="epoch",
        greater_is_better=False,
        # save_steps=5,
        # eval_steps=5,
    )
    trainer = None

    # PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
    last_checkpoint_path = None
    last_epoch_idx = 0
    if os.path.exists(checkpoint_path):
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)
        if last_checkpoint_path is not None:
            with open(os.path.join(last_checkpoint_path, "trainer_state.json"), 'r', encoding='utf-8') as file:
                trainer_state = json.load(file)
                last_epoch_idx = int(trainer_state['epoch'])

    w2v_ctc_model, w2v_ctc_processor = load_pretrained_model()
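    # DataCollatorCTCWithPadding (data_handler.py) is expected to dynamically pad
    # input_values and labels to the longest example in each batch.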
    data_collator = DataCollatorCTCWithPadding(processor=w2v_ctc_processor, padding=True)

    for epoch_idx in range(last_epoch_idx, num_epochs):
        # loop over training shards
        train_dataset_shard_idx = epoch_idx % num_train_shards
        # pick the test shard that corresponds to the current train shard (clamped to a valid index)
        test_dataset_shard_idx = min(round(train_dataset_shard_idx / (num_train_shards / num_test_shards)),
                                     num_test_shards - 1)
        num_test_sub_shard = 8  # split each test shard into 8 subsets
        idx_sub_shard = train_dataset_shard_idx % num_test_sub_shard  # rotate over the test sub-shards

        # load train shard
        train_dataset = load_prepared_dataset(os.path.join(train_dataset_root_folder,
                                                           'shard_{}'.format(train_dataset_shard_idx)),
                                              w2v_ctc_processor,
                                              cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
                                                                                  'train',
                                                                                  'cache-train-filter-shard-{}.arrow'.format(
                                                                                      train_dataset_shard_idx)),
                                              cache_file_map_name=os.path.join(cache_processing_dataset_folder,
                                                                               'train',
                                                                               'cache-train-map-shard-{}.arrow'.format(
                                                                                   train_dataset_shard_idx)),
                                              )  # append .shard(1000, 0) only when debugging on a small slice
        # load test shard subset
        test_dataset = load_prepared_dataset(os.path.join(test_dataset_root_folder,
                                                          'shard_{}'.format(test_dataset_shard_idx)),
                                             w2v_ctc_processor,
                                             cache_file_filter_name=os.path.join(cache_processing_dataset_folder,
                                                                                 'test',
                                                                                 'cache-test-filter-shard-{}.arrow'.format(
                                                                                     test_dataset_shard_idx)),
                                             cache_file_map_name=os.path.join(cache_processing_dataset_folder, 'test',
                                                                              'cache-test-map-shard-{}.arrow'.format(
                                                                                  test_dataset_shard_idx))
                                             )
        test_dataset = test_dataset.shard(num_test_sub_shard, idx_sub_shard)
        # Init trainer
        if trainer is None:
            trainer = Trainer(
                model=w2v_ctc_model,
                data_collator=data_collator,
                args=training_args,
                compute_metrics=compute_metrics_fn(w2v_ctc_processor),
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                tokenizer=w2v_ctc_processor.feature_extractor,
                callbacks=[BreakEachEpoch()]  # stop the Trainer at the end of each epoch, since an epoch only covers one shard
            )
        else:
            trainer.train_dataset = train_dataset
            trainer.eval_dataset = test_dataset

        logging.get_logger().info('Train shard idx: {} / {}'.format(train_dataset_shard_idx + 1, num_train_shards))
        logging.get_logger().info(
            'Valid shard idx: {} / {} sub_shard: {}'.format(test_dataset_shard_idx + 1, num_test_shards, idx_sub_shard))

        if last_checkpoint_path is not None:
            # resume model, optimizer and scheduler state from the last checkpoint
            trainer.train(resume_from_checkpoint=True)
        else:
            # first run: start from the pre-trained wav2vec2 weights
            trainer.train()
        last_checkpoint_path = get_last_checkpoint(checkpoint_path)

        # clean up this shard's cache files to free disk space
        test_dataset.cleanup_cache_files()
        train_dataset.cleanup_cache_files()