from datasets import Audio, interleave_datasets, IterableDataset, load_dataset
from typing import List, Optional
dataset_names = ["mozilla-foundation/common_voice_11_0", "google/fleurs"]
dataset_config_names = ["da", "da_dk"]
text_column_names = ["sentence", "normalized_text", "text", "transcription"]
from datasets import interleave_datasets, load_dataset
def load_streaming_dataset(dataset_name, dataset_config_name, split, **kwargs):
    if "+" in split:
        # load multiple splits separated by the `+` symbol *with* streaming mode
        dataset_splits = [
            load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=True, **kwargs)
            for split_name in split.split("+")
        ]
        # interleave multiple splits to form one dataset
        interleaved_dataset = interleave_datasets(dataset_splits)
        return interleaved_dataset
    else:
        # load a single split *with* streaming mode
        dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=True, **kwargs)
        return dataset
from datasets import IterableDatasetDict
raw_datasets = IterableDatasetDict()
raw_datasets["train"] = load_streaming_dataset("mozilla-foundation/common_voice_11_0", "da", split="train+validation", use_auth_token=True) # set split="train+validation" for low-resource
raw_datasets["test"] = load_streaming_dataset("mozilla-foundation/common_voice_11_0", "da", split="test", use_auth_token=True)
from transformers import WhisperProcessor
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Danish", task="transcribe")
from datasets import Audio
raw_datasets = raw_datasets.cast_column("audio", Audio(sampling_rate=16000))
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()
def prepare_dataset(batch):
    # load and (possibly) resample audio data to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # compute input length of audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    # optional pre-processing steps
    transcription = batch["sentence"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()

    # encode target text to label ids
    batch["labels"] = processor.tokenizer(transcription).input_ids
    return batch
vectorized_datasets = raw_datasets.map(
    prepare_dataset,
    remove_columns=list(next(iter(raw_datasets.values())).features),
).with_format("torch")
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
buffer_size=500,
seed=0,
)
max_input_length = 30.0
def is_audio_in_length_range(length):
    return length < max_input_length

vectorized_datasets["train"] = vectorized_datasets["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)
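# (Note, not in the original script) The 30.0 s threshold matches Whisper's fixed input window: the
# feature extractor pads or truncates every clip to 30 s of audio (3000 log-Mel frames), so longer
# samples would be silently cut off -- the filter above drops them instead. A quick sanity check on
# one streamed example, assuming the Common Voice stream is reachable:
#
# sample = next(iter(vectorized_datasets["train"]))
# print(sample["input_features"].shape)  # expected: torch.Size([80, 3000])
# print(sample["input_length"])          # duration of the raw clip in seconds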
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was appended in the previous tokenization step,
        # cut it here as it's appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
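# (Sketch, assumption -- not part of the original script) The collator can be sanity-checked on a
# couple of streamed examples; padded label positions (if any) should show up as -100:
#
# sample_batch = data_collator(list(vectorized_datasets["train"].take(2)))
# print(sample_batch["input_features"].shape)  # one 30 s log-Mel spectrogram per example
# print(sample_batch["labels"])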
import evaluate
metric = evaluate.load("wer")
# evaluate with the 'normalised' WER
do_normalize_eval = True
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    if do_normalize_eval:
        pred_str = [normalizer(pred) for pred in pred_str]
        label_str = [normalizer(label) for label in label_str]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
from transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []
model.config.use_cache = False
from transformers import Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-07,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=32,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    # optim="adamw_bnb_8bit",
)
from transformers import TrainerCallback
from transformers.trainer_pt_utils import IterableDatasetShard
from torch.utils.data import IterableDataset
# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
class ShuffleCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        if isinstance(train_dataloader.dataset, IterableDatasetShard):
            pass  # set_epoch() is handled by the Trainer
        elif isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vectorized_datasets["train"],
    eval_dataset=vectorized_datasets["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    callbacks=[ShuffleCallback()],
)
model.save_pretrained(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)
trainer.train()
kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_11_0",
    "dataset": "Common Voice 11.0, FLEURS",  # a 'pretty' name for the training dataset
    "language": "da",
    "model_name": "Whisper Small da - Common Voice+FLEURS",  # a 'pretty' name for your model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
    "tags": "whisper-event",
}
trainer.push_to_hub(**kwargs)
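# (Sketch, assumption -- not part of the original script) Once training has finished, the checkpoint
# saved in `training_args.output_dir` can be tried out with the ASR pipeline, forcing Danish
# transcription via the decoder prompt ids; "audio.wav" is a hypothetical local file:
#
# from transformers import pipeline
# asr = pipeline("automatic-speech-recognition", model=training_args.output_dir)
# asr.model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="danish", task="transcribe")
# print(asr("audio.wav"))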