MMS / infer.py
GreenRaptor's picture
Create infer.py
52fe7b2
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Run inference for pre-processed data with a trained model.
"""
import ast
import logging
import math
import os
import sys
import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter
# Make sure the root logger has a handler and emits INFO records.
# ``basicConfig`` does nothing when the root logger already has handlers, so
# the level is also set explicitly to cover that case.
logging.basicConfig(level=logging.INFO)
logging.root.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
def add_asr_eval_argument(parser):
    """Register the ASR evaluation / decoding options on ``parser``.

    Args:
        parser: an ``argparse``-style parser exposing ``add_argument``.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    # Local import: only needed to recognise an option-name conflict below.
    import argparse

    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="rnnt decoding type",
    )
    try:
        parser.add_argument(
            "--lm-weight",
            "--lm_weight",
            type=float,
            default=0.2,
            help="weight for lm while interpolating with neural score",
        )
    except argparse.ArgumentError:
        # The parser handed in may already define this option; in that case
        # the existing definition is kept as-is.
        pass
    parser.add_argument(
        "--rnnt_len_penalty",
        type=float,  # without this, command-line values would stay strings
        default=-0.5,
        help="rnnt length penalty on word level",
    )
    parser.add_argument(
        "--w2l-decoder",
        choices=["viterbi", "kenlm", "fairseqlm"],
        help="use a w2l decoder",
    )
    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
    parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
    parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
    parser.add_argument("--beam-threshold", type=float, default=25.0)
    parser.add_argument("--beam-size-token", type=float, default=100)
    parser.add_argument("--word-score", type=float, default=1.0)
    parser.add_argument("--unk-weight", type=float, default=-math.inf)
    parser.add_argument("--sil-weight", type=float, default=0.0)
    parser.add_argument(
        "--dump-emissions",
        type=str,
        default=None,
        help="if present, dumps emissions into this file and exits",
    )
    parser.add_argument(
        "--dump-features",
        type=str,
        default=None,
        help="if present, dumps features into this file and exits",
    )
    parser.add_argument(
        "--load-emissions",
        type=str,
        default=None,
        help="if present, loads emissions from this file",
    )
    return parser
def check_args(args):
    """Sanity-check option combinations before decoding.

    ``--path`` and ``--results_path`` are deliberately not enforced here.

    Raises:
        AssertionError: if ``--sampling`` is set while ``--nbest`` differs
            from ``--beam``, or if ``--replace-unk`` is given without
            ``--raw-text``.
    """
    if args.sampling:
        assert (
            args.nbest == args.beam
        ), "--sampling requires --nbest to be equal to --beam"
    if args.replace_unk is not None:
        assert (
            args.raw_text
        ), "--replace-unk requires a raw text dataset (--raw-text)"
def get_dataset_itr(args, task, models):
    """Build an unshuffled epoch iterator over ``args.gen_subset``.

    ``models`` is accepted for signature compatibility but is not used;
    maximum positions are fixed to ``sys.maxsize`` for both dimensions.
    """
    unlimited = (sys.maxsize, sys.maxsize)
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=unlimited,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    )
    return batch_iterator.next_epoch_itr(shuffle=False)
def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
    """Score the top hypothesis of one utterance against its reference.

    Only the first of the (at most ``args.nbest``) hypotheses is used. When
    ``res_files`` is given, the unit- and word-level hypothesis and reference
    strings are written to it, each tagged with ``(speaker-id)``. ``sp`` is
    not used.

    Returns:
        ``(word_edit_distance, reference_word_count)``, or ``None`` when
        there is no hypothesis to score.
    """

    def tagged(text):
        # Line format used in the result files.
        return "{} ({}-{})".format(text, speaker, id)

    candidates = hypos[: min(len(hypos), args.nbest)]
    if len(candidates) == 0:
        return None
    best = candidates[0]

    hyp_pieces = tgt_dict.string(best["tokens"].int().cpu())
    if "words" not in best:
        hyp_words = post_process(hyp_pieces, args.post_process)
    else:
        hyp_words = " ".join(best["words"])

    write_results = res_files is not None
    if write_results:
        print(tagged(hyp_pieces), file=res_files["hypo.units"])
        print(tagged(hyp_words), file=res_files["hypo.words"])

    tgt_pieces = tgt_dict.string(target_tokens)
    tgt_words = post_process(tgt_pieces, args.post_process)

    if write_results:
        print(tagged(tgt_pieces), file=res_files["ref.units"])
        print(tagged(tgt_words), file=res_files["ref.words"])

    if not args.quiet:
        logger.info("HYPO:" + hyp_words)
        logger.info("TARGET:" + tgt_words)
        logger.info("___________________")

    hyp_tokens = hyp_words.split()
    ref_tokens = tgt_words.split()
    return editdistance.eval(hyp_tokens, ref_tokens), len(ref_tokens)
def prepare_result_files(args):
    """Open the four transcript files used to record decoding results.

    Files are created inside ``args.results_path`` and named
    ``<prefix>-<basename of args.path>-<args.gen_subset>.txt``; when more
    than one shard is used the name is additionally prefixed with
    ``<args.shard_id>_``. They are opened line-buffered and as UTF-8 so that
    non-ASCII transcripts do not depend on the locale's default encoding.

    Returns:
        ``None`` if ``args.results_path`` is empty or unset, otherwise a dict
        mapping ``"hypo.words"``, ``"hypo.units"``, ``"ref.words"`` and
        ``"ref.units"`` to open text files. The caller must close them.
    """
    if not args.results_path:
        return None

    def get_res_file(file_prefix):
        if args.num_shards > 1:
            file_prefix = f"{args.shard_id}_{file_prefix}"
        path = os.path.join(
            args.results_path,
            "{}-{}-{}.txt".format(
                file_prefix, os.path.basename(args.path), args.gen_subset
            ),
        )
        return open(path, "w", buffering=1, encoding="utf-8")

    # Dict key -> file-name prefix (the word files use the singular "word").
    prefixes = {
        "hypo.words": "hypo.word",
        "hypo.units": "hypo.units",
        "ref.words": "ref.word",
        "ref.units": "ref.units",
    }
    return {key: get_res_file(prefix) for key, prefix in prefixes.items()}
def optimize_models(args, use_cuda, models):
    """Prepare every model of the ensemble for generation, in place.

    Each model is switched to its fast generation mode, then converted to
    half precision when ``args.fp16`` is set and moved to the GPU when
    ``use_cuda`` is true.
    """
    for net in models:
        if args.no_beamable_mm:
            mm_beam_size = None
        else:
            mm_beam_size = args.beam
        net.make_generation_fast_(
            beamable_mm_beam_size=mm_beam_size,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            net.half()
        if use_cuda:
            net.cuda()
def apply_half(t):
    """Return ``t`` as half precision if it is float32, else ``t`` itself."""
    return t.to(dtype=torch.half) if t.dtype is torch.float32 else t
class ExistingEmissionsDecoder:
    """Decoder wrapper that feeds pre-computed emissions to ``decoder``.

    Args:
        decoder: object exposing ``decode(emissions)``.
        emissions: array of per-utterance emission arrays, indexable by an
            array of sample ids.
    """

    def __init__(self, decoder, emissions):
        self.decoder = decoder
        self.emissions = emissions

    def generate(self, models, sample, **unused):
        """Decode the stored emissions of the utterances in ``sample``.

        ``models`` and any extra keyword arguments are ignored.

        Raises:
            Exception: with a message starting with ``"invalid sizes"`` when
                the selected emissions cannot be stacked into one batch; the
                underlying ``ValueError`` is attached as the cause.
        """
        ids = sample["id"].cpu().numpy()
        selected = self.emissions[ids]
        try:
            batch = np.stack(selected)
        except ValueError as err:
            # Stacking fails when the arrays disagree in shape; report the
            # shapes so the offending utterances can be identified.
            shapes = [x.shape for x in selected]
            raise Exception("invalid sizes: {}".format(shapes)) from err
        return self.decoder.decode(torch.from_numpy(batch))
def main(args, task=None, model_state=None):
    """Decode ``args.gen_subset`` and report the word error rate.

    Depending on the options this either dumps emissions or features to a
    file, or runs the configured decoder and scores every utterance against
    its reference transcript.

    Args:
        args: parsed generation and ASR evaluation options.
        task: ignored; a task is always built with ``tasks.setup_task``.
        model_state: optional state forwarded to
            ``checkpoint_utils.load_model_ensemble_and_task``.

    Returns:
        ``(task, wer)`` where ``wer`` is a percentage, or ``None`` when
        emissions/features were dumped or no reference words were scored.

    Raises:
        NotImplementedError: if ``args.criterion`` is ``"asg_loss"``.
    """
    check_args(args)

    use_fp16 = args.fp16
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info("| decoding with criterion {}".format(args.criterion))

    task = tasks.setup_task(args)

    # Load ensemble
    if args.load_emissions:
        # Emissions are read from disk, so no model is loaded.
        models = []
        task.load_dataset(args.gen_subset)
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            utils.split_paths(args.path, separator="\\"),
            arg_overrides=ast.literal_eval(args.model_overrides),
            task=task,
            suffix=args.checkpoint_suffix,
            strict=(args.checkpoint_shard_count == 1),
            num_shards=args.checkpoint_shard_count,
            state=model_state,
        )
        optimize_models(args, use_cuda, models)
        task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info(
        "| {} {} {} examples".format(
            args.data, args.gen_subset, len(task.dataset(args.gen_subset))
        )
    )

    if args.criterion == "asg_loss":
        raise NotImplementedError("asg_loss is currently not supported")

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        """Return the requested flashlight decoder, or ``None`` if unsupported."""
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        if w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        if w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        # Not fatal: the dump modes below never use the generator.
        logger.warning(
            "only flashlight decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
        )
        return None

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load
        # emission files from a trusted source.
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    # An empty results path means "no result files", so nothing to create.
    if args.results_path:
        os.makedirs(args.results_path, exist_ok=True)

    # NOTE(review): max_source_pos is computed but not read again in this
    # function; kept so the max-positions resolution still runs as before.
    max_source_pos = (
        utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
    )
    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    emissions = {} if args.dump_emissions else None
    features = None
    res_files = None
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)

    errs_t = 0
    lengths_t = 0
    wer = None
    try:
        with progress_bar.build_progress_bar(args, itr) as t:
            wps_meter = TimeMeter()
            for sample in t:
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                if use_fp16:
                    sample = utils.apply_to_sample(apply_half, sample)
                if "net_input" not in sample:
                    continue

                prefix_tokens = None
                if args.prefix_size > 0:
                    prefix_tokens = sample["target"][:, : args.prefix_size]

                gen_timer.start()
                if args.dump_emissions:
                    with torch.no_grad():
                        encoder_out = models[0](**sample["net_input"])
                        emm = models[0].get_normalized_probs(
                            encoder_out, log_probs=True
                        )
                        emm = emm.transpose(0, 1).cpu().numpy()
                        for i, utt_id in enumerate(sample["id"]):
                            emissions[utt_id.item()] = emm[i]
                    continue
                elif args.dump_features:
                    with torch.no_grad():
                        encoder_out = models[0](**sample["net_input"])
                        feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                        padding_mask = encoder_out["encoder_padding_mask"]
                        for i, utt_id in enumerate(sample["id"]):
                            padding = (
                                padding_mask[i].cpu().numpy()
                                if padding_mask is not None
                                else None
                            )
                            features[utt_id.item()] = (feat[i], padding)
                    continue

                hypos = task.inference_step(generator, models, sample, prefix_tokens)
                num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
                gen_timer.stop(num_generated_tokens)

                for i, sample_id in enumerate(sample["id"].tolist()):
                    speaker = None
                    toks = (
                        sample["target"][i, :]
                        if "target_label" not in sample
                        else sample["target_label"][i, :]
                    )
                    target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                    # Process top predictions
                    errs, length = process_predictions(
                        args,
                        hypos[i],
                        None,
                        tgt_dict,
                        target_tokens,
                        res_files,
                        speaker,
                        sample_id,
                    )
                    errs_t += errs
                    lengths_t += length

                wps_meter.update(num_generated_tokens)
                t.log({"wps": round(wps_meter.avg)})
                num_sentences += (
                    sample["nsentences"]
                    if "nsentences" in sample
                    else sample["id"].numel()
                )

        if args.dump_emissions:
            # Entries are looked up by consecutive sample id starting at 0.
            emm_arr = [emissions[i] for i in range(len(emissions))]
            np.save(args.dump_emissions, emm_arr)
            logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
        elif args.dump_features:
            feat_arr = [features[i] for i in range(len(features))]
            np.save(args.dump_features, feat_arr)
            logger.info(f"saved {len(features)} features to {args.dump_features}")
        else:
            if lengths_t > 0:
                wer = errs_t * 100.0 / lengths_t
                logger.info(f"WER: {wer}")

            # Guard the rates: the timer stays at zero when nothing was
            # decoded (e.g. every sample was skipped).
            elapsed = gen_timer.sum
            sentences_per_s = num_sentences / elapsed if elapsed > 0 else 0.0
            avg_time = gen_timer.avg
            tokens_per_s = 1.0 / avg_time if avg_time > 0 else 0.0
            logger.info(
                "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
                "sentences/s, {:.2f} tokens/s)".format(
                    num_sentences,
                    gen_timer.n,
                    elapsed,
                    sentences_per_s,
                    tokens_per_s,
                )
            )
            logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    finally:
        # Release the transcript files even if decoding failed part-way.
        if res_files is not None:
            for handle in res_files.values():
                handle.close()
    return task, wer
def make_parser():
    """Return the generation parser extended with the ASR evaluation options."""
    return add_asr_eval_argument(options.get_generation_parser())
def cli_main():
    """Command-line entry point: parse the arguments and run ``main``."""
    cli_parser = make_parser()
    parsed_args = options.parse_args_and_arch(cli_parser)
    main(parsed_args)


if __name__ == "__main__":
    cli_main()