import argparse
import os
import re

import numpy as np

# Strip a possible "CUDA" prefix from the device list (some job schedulers
# export device names like "CUDA0" instead of "0"). Only touch the variable
# if it is actually set, to avoid a KeyError.
if "CUDA_VISIBLE_DEVICES" in os.environ:
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"].replace("CUDA", "")

from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio

whisper_norm = BasicTextNormalizer()


def simple_norm(utt):
    norm_utt = re.sub(r'[^\w\s]', '', utt)  # remove punctuation
    norm_utt = " ".join(norm_utt.split())   # collapse whitespace
    norm_utt = norm_utt.lower()
    return norm_utt


def data(dataset):
    # Yield audio dicts with the reference text and utterance id attached,
    # so the pipeline passes them through alongside each prediction.
    for item in dataset:
        yield {**item["audio"], "reference": item["text"], "utt_id": item["id"]}


def get_ckpt(path, ckpt_id):
    if ckpt_id != 0:
        model = os.path.join(path, "checkpoint-%i" % ckpt_id)
    else:
        # Fall back to the latest checkpoint in the experiment directory.
        dirs = [d for d in os.listdir(path) if d.startswith("checkpoint-")]
        ckpts = [int(d.split('-')[-1]) for d in dirs]
        last_ckpt = max(ckpts)
        model = os.path.join(path, "checkpoint-%i" % last_ckpt)
    return model


def main(args):
    batch_size = args.batch_size

    if args.device == "cpu":
        device_id = -1
    elif args.device == "gpu":
        device_id = 0
    else:
        raise NotImplementedError("unknown device %s, should be cpu/gpu" % args.device)

    model_dir = os.path.join(args.expdir, args.model_size)
    #model = os.path.join(get_ckpt(model_dir, args.checkpoint), 'pytorch_model.bin')
    #model = get_ckpt(model_dir, args.checkpoint)
    model = model_dir
    #model = "openai/whisper-tiny"

    whisper_asr = pipeline(
        "automatic-speech-recognition", model=model, device=device_id
    )

    # Force the decoder to transcribe in the target language.
    whisper_asr.model.config.forced_decoder_ids = (
        whisper_asr.tokenizer.get_decoder_prompt_ids(
            language=args.language, task="transcribe"
        )
    )

    if args.dataset == 'cgn-dev':
        dataset_path = "./cgn-dev/cgn-dev.py"
    elif args.dataset == 'subs-annot':
        dataset_path = "./subs-annot/subs-annot.py"
    else:
        raise NotImplementedError('unknown dataset %s' % args.dataset)

    cache_dir = "/esat/audioslave/jponcele/hf_cache"

    dataset = load_dataset(dataset_path, name="raw", split="test",
                           cache_dir=cache_dir, streaming=True)
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    utterances = []
    predictions = []
    references = []

    # run streamed inference
    for out in whisper_asr(data(dataset), batch_size=batch_size):
        predictions.append(out["text"])
        utterances.append(out["utt_id"][0])
        references.append(out["reference"][0])
        #break

    result_dir = os.path.join(args.expdir, "results", args.dataset)
    os.makedirs(result_dir, exist_ok=True)

    # Write the hypotheses three times: raw, Whisper-normalized (normW)
    # and simple-normalized (normS), one "utt_id prediction" pair per line.
    with open(os.path.join(result_dir, "whisper_%s.txt" % args.model_size), "w") as fd:
        for i, utt in enumerate(utterances):
            pred = predictions[i]
            fd.write(utt + ' ' + pred + '\n')

    with open(os.path.join(result_dir, "whisper_%s_normW.txt" % args.model_size), "w") as fd:
        for i, utt in enumerate(utterances):
            pred = whisper_norm(predictions[i])
            fd.write(utt + ' ' + pred + '\n')

    with open(os.path.join(result_dir, "whisper_%s_normS.txt" % args.model_size), "w") as fd:
        for i, utt in enumerate(utterances):
            pred = simple_norm(predictions[i])
            fd.write(utt + ' ' + pred + '\n')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--expdir",
        type=str,
        default="/esat/audioslave/jponcele/whisper/finetuning_event/CGN",
        help="Directory with finetuned models",
    )
    parser.add_argument(
        "--model_size",
        type=str,
        default="tiny",
        help="Model size",
    )
    parser.add_argument(
        "--checkpoint",
        type=int,
        default=0,
        help="Load specific checkpoint. 0 means latest",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cgn-dev",
        help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="cpu/gpu",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Number of samples to go through each streamed batch.",
    )
    parser.add_argument(
        "--language",
        type=str,
        default="dutch",
        help="Transcription language, as a name or two-letter code, e.g. 'dutch' or 'nl'.",
    )
    args = parser.parse_args()

    main(args)
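# A minimal sketch of how this script might be invoked. The script filename
# below is an assumption (it is not given in the source); the argument values
# simply echo the defaults defined above:
#
#   python evaluate_whisper.py \
#       --expdir /esat/audioslave/jponcele/whisper/finetuning_event/CGN \
#       --model_size tiny --dataset cgn-dev --device gpu \
#       --batch_size 16 --language dutch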