"""
This script provides an example to wrap TencentPretrain for speech-to-text inference.
"""
import sys
import os
import tqdm
import argparse
import math
import torch
import torch.nn.functional as F

tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.model_loader import load_model
from tencentpretrain.opts import infer_opts, tokenizer_opts
from finetune.run_speech2text import Speech2text, read_dataset
from inference.run_classifier_infer import batch_loader


def _reorder_beam_history(tokens, parent_beams, step):
    """Make every surviving beam inherit the token prefix of its parent beam.

    Args:
        tokens: Long tensor of shape (batch, beam, length); modified in place.
        parent_beams: Long tensor of shape (batch, beam) holding, for each new
            beam, the row of its parent in ``tokens`` flattened to
            (batch * beam, length).
        step: Current decoding step; positions ``0 .. step`` are the history.
            Position ``step + 1`` (the freshly selected token) is left untouched
            because it is already stored in new-beam order.
    """
    batch, beam, length = tokens.shape
    # index_select returns a copy, so writing back into `tokens` is safe.
    history = torch.index_select(tokens.reshape(-1, length), 0, parent_beams.reshape(-1))
    tokens[:, :, :step + 1] = history.reshape(batch, beam, length)[:, :, :step + 1]


def _tokens_to_text(tokenizer, token_ids):
    """Turn one hypothesis (a list of token ids) into the output text line."""
    text = "".join(tokenizer.inv_vocab[token_id] for token_id in token_ids)
    # The hypothesis starts with SEP; keep what lies between it and the first CLS.
    # (Original note: use res.split(CLS_TOKEN)[1].split(SEP_TOKEN)[0] for tencentpretrain model.)
    text = text.split(SEP_TOKEN)[1].split(CLS_TOKEN)[0]
    return text.replace('▁', ' ')


def main():
    """Run beam-search speech-to-text inference and write one transcript per line.

    Reads the test set given by ``--test_path``, decodes every batch with a beam
    of ``--beam_width`` hypotheses for at most ``--tgt_seq_length`` steps, and
    writes the best hypothesis of each sample to ``--prediction_path``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    infer_opts(parser)
    tokenizer_opts(parser)

    parser.add_argument("--beam_width", type=int, default=10,
                        help="Beam width.")
    parser.add_argument("--tgt_seq_length", type=int, default=100,
                        help="inference step.")

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Build tokenizer.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # Build s2t model.
    model = Speech2text(args)
    model = load_model(model, args.load_model_path)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(args.device)
    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    dataset = read_dataset(args, args.test_path)

    src_audio = torch.stack([sample[0] for sample in dataset], dim=0)
    seg_audio = torch.LongTensor([sample[3] for sample in dataset])

    batch_size = args.batch_size
    beam_width = args.beam_width
    instances_num = src_audio.size()[0]
    tgt_seq_length = args.tgt_seq_length

    # convert_tokens_to_ids returns a list; the lists are used directly as
    # column indices below, and sep_ids[0] is the decoder start token.
    pad_ids = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])
    sep_ids = args.tokenizer.convert_tokens_to_ids([SEP_TOKEN])

    print("The number of prediction instances: ", instances_num)

    model.eval()

    with open(args.prediction_path, mode="w", encoding="utf-8") as f, torch.no_grad():
        for src_batch, seg_batch in tqdm.tqdm(batch_loader(batch_size, src_audio, seg_audio)):
            src_batch = src_batch.to(args.device)
            seg_batch = seg_batch.to(args.device)

            # Trim padding: keep only as many frames as the longest sample needs.
            seq_length = int(seg_batch.sum(dim=-1).max())
            src_batch = src_batch[:, :seq_length * 4, :]
            seg_batch = seg_batch[:, :seq_length]

            current_batch_size = src_batch.size()[0]
            # Decoder input starts with SEP for every sample.
            # (Original note: use CLS as the start token for tencentpretrain model.)
            tgt_in_batch = torch.full((current_batch_size, 1), sep_ids[0],
                                      dtype=torch.long, device=args.device)

            memory_bank, emb = model(src_batch, None, seg_batch, None, only_use_encoder=True)

            # Cumulative log-probabilities and token ids of every beam (kept on CPU).
            scores = torch.zeros(current_batch_size, beam_width, tgt_seq_length)
            tokens = torch.zeros(current_batch_size, beam_width, tgt_seq_length + 1, dtype=torch.long)
            tokens[:, :, 0] = sep_ids[0]

            # Replicate encoder outputs once per beam; beams of one sample stay adjacent.
            emb = emb.repeat(1, beam_width, 1).reshape(
                current_batch_size * beam_width, -1, int(args.conv_channels[-1] / 2))
            memory_bank = memory_bank.repeat(1, beam_width, 1).reshape(
                current_batch_size * beam_width, -1, args.emb_size)
            tgt_in_batch = tgt_in_batch.repeat(beam_width, 1)

            # Row offset of each sample inside the flattened (batch * beam) layout.
            batch_offsets = torch.arange(current_batch_size).unsqueeze(1) * beam_width

            for step in range(min(tgt_seq_length, seq_length)):
                outputs = model(emb, (tgt_in_batch, None, None), None, None, memory_bank=memory_bank)

                vocab_size = outputs.shape[-1]
                # Already (batch * beam, vocab). No squeeze(): it would drop the
                # row dimension when batch * beam == 1 and break the indexing below.
                log_prob = F.log_softmax(outputs[:, -1, :], dim=-1)
                log_prob[:, pad_ids] = -math.inf  # do not select pad
                if step == 0:
                    log_prob[:, sep_ids] = -math.inf  # do not end immediately

                log_prob_beam = log_prob.reshape(current_batch_size, beam_width, -1).to(scores.device)
                if step == 0:
                    # All beams are identical at the first step; expand from one only.
                    log_prob_beam = log_prob_beam[:, :1, :]
                else:
                    log_prob_beam = log_prob_beam + scores[:, :, step - 1].unsqueeze(-1)

                top_prediction_prob, top_prediction_indices = torch.topk(
                    log_prob_beam.reshape(current_batch_size, -1), k=beam_width)

                # Flat candidate index = parent_beam * vocab_size + token_id.
                parent_beams = torch.div(top_prediction_indices, vocab_size, rounding_mode="floor")
                parent_beams = parent_beams + batch_offsets
                top_prediction_indices = top_prediction_indices.fmod(vocab_size)

                scores[:, :, step] = top_prediction_prob
                tokens[:, :, step + 1] = top_prediction_indices
                if step > 0:
                    # Every new beam must carry its parent's prefix, for any batch size.
                    _reorder_beam_history(tokens, parent_beams, step)

                tgt_in_batch = tokens[:, :, :step + 2].reshape(current_batch_size * beam_width, -1)
                tgt_in_batch = tgt_in_batch.long().to(emb.device)

            # topk returns sorted scores, so beam 0 is the best hypothesis.
            for sample_index in range(current_batch_size):
                res = _tokens_to_text(args.tokenizer, tokens[sample_index, 0, :].tolist())
                f.write(res)
                f.write("\n")


if __name__ == "__main__":
    main()