"""
This script provides an example to wrap TencentPretrain for speech-to-text inference.
"""
import sys
import os
import tqdm
import argparse
import math
import torch
import torch.nn.functional as F

tencentpretrain_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(tencentpretrain_dir)

from tencentpretrain.utils.constants import *
from tencentpretrain.utils import *
from tencentpretrain.utils.config import load_hyperparam
from tencentpretrain.model_loader import load_model
from tencentpretrain.opts import infer_opts, tokenizer_opts
from finetune.run_speech2text import Speech2text, read_dataset
from inference.run_classifier_infer import batch_loader
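# Speech2text, read_dataset, and batch_loader are reused from the finetuning and
# classifier-inference scripts rather than re-implemented here.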


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    infer_opts(parser)

    tokenizer_opts(parser)

    parser.add_argument("--beam_width", type=int, default=10,
                        help="Beam width.")
    parser.add_argument("--tgt_seq_length", type=int, default=100,
                        help="inference step.")
    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Build tokenizer.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # Build s2t model.
    model = Speech2text(args)
    model = load_model(model, args.load_model_path)

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(args.device)

    if torch.cuda.device_count() > 1:
        print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
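        # DataParallel splits each batch across the visible GPUs during the forward pass.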

    dataset = read_dataset(args, args.test_path)

    src_audio = torch.stack([sample[0] for sample in dataset], dim=0)  # audio feature tensors
    seg_audio = torch.LongTensor([sample[3] for sample in dataset])  # segment ids marking valid positions

    batch_size = args.batch_size
    beam_width = args.beam_width
    instances_num = src_audio.size()[0]
    tgt_seq_length = args.tgt_seq_length

    PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])
    SEP_ID = args.tokenizer.convert_tokens_to_ids([SEP_TOKEN])
    CLS_ID = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN])
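    # convert_tokens_to_ids returns a list, so PAD_ID / SEP_ID / CLS_ID are
    # one-element lists; they are used below for tensor indexing and assignment.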
    print("The number of prediction instances: ", instances_num)

    model.eval()
    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
        for src_batch, seg_batch in tqdm.tqdm(batch_loader(batch_size, src_audio, seg_audio)):
            src_batch = src_batch.to(args.device)
            seg_batch = seg_batch.to(args.device)

            # Trim padding: keep only valid segment positions; each segment
            # position corresponds to 4 audio frames.
            seq_length = seg_batch.sum(dim=-1).max()
            src_batch = src_batch[:, :seq_length * 4, :]
            seg_batch = seg_batch[:, :seq_length]

            # Start every hypothesis with the SEP token id.
            tgt_in_batch = torch.zeros(src_batch.size()[0], 1, dtype=torch.long, device=args.device)
            current_batch_size = tgt_in_batch.size()[0]
            tgt_in_batch[:, 0] = torch.LongTensor(SEP_ID)  # use CLS_ID instead for a tencentpretrain model

            # Run the encoder once per batch; the decoder reuses memory_bank below.
            with torch.no_grad():
                memory_bank, emb = model(src_batch, None, seg_batch, None, only_use_encoder=True)

            step = 0
            scores = torch.zeros(current_batch_size, beam_width, tgt_seq_length)
            tokens = torch.zeros(current_batch_size, beam_width, tgt_seq_length + 1, dtype=torch.long)
            tokens[:, :, 0] = torch.LongTensor(SEP_ID)
            # Replicate the encoder outputs beam_width times; copies of the same
            # sample end up in adjacent rows.
            emb = emb.repeat(1, beam_width, 1).reshape(current_batch_size * beam_width, -1, int(args.conv_channels[-1] / 2))
            memory_bank = memory_bank.repeat(1, beam_width, 1).reshape(current_batch_size * beam_width, -1, args.emb_size)
            tgt_in_batch = tgt_in_batch.repeat(beam_width, 1)
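            # From here on, each of the current_batch_size * beam_width rows decodes one
            # hypothesis; decoding stops at tgt_seq_length or the source segment length.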
            while step < tgt_seq_length and step < seq_length:
                with torch.no_grad():
                    outputs = model(emb, (tgt_in_batch, None, None), None, None, memory_bank=memory_bank)

                vocab_size = outputs.shape[-1]
                log_prob = F.log_softmax(outputs[:, -1, :], dim=-1)  # (batch * beam_width) x vocab_size

                log_prob[:, PAD_ID] = -math.inf  # never select the padding token
                if step == 0:
                    log_prob[:, SEP_ID] = -math.inf  # do not emit </s> as the first token

                log_prob_beam = log_prob.reshape(current_batch_size, beam_width, -1)  # batch x beam_width x vocab_size

                if step == 0:
                    # All beams are identical at the first step, so score only one beam per sample.
                    log_prob_beam = log_prob_beam[:, ::beam_width, :].contiguous().to(scores.device)
                else:
                    # Add the running log-probability of each beam.
                    log_prob_beam = log_prob_beam.to(scores.device) + scores[:, :, step-1].unsqueeze(-1)
                
                # topk over beam_width * vocab_size candidates per sample; each flat
                # index decomposes into (originating beam, token id).
                top_prediction_prob, top_prediction_indices = torch.topk(log_prob_beam.view(current_batch_size, -1), k=beam_width)
                beams_buf = torch.div(top_prediction_indices, vocab_size).trunc().long()
                beams_buf = beams_buf + torch.arange(current_batch_size).repeat(beam_width).reshape(beam_width, -1).transpose(0, 1) * beam_width  # offsets into the flattened batch * beam layout
                top_prediction_indices = top_prediction_indices.fmod(vocab_size)
                
                scores[:, :, step] = top_prediction_prob
                tokens[:, :, step+1] = top_prediction_indices

                # Reorder the already-generated prefixes so that each surviving beam
                # carries the history of the beam it came from.
                if step > 0 and current_batch_size == 1:
                    tokens[:, :, :step+1] = torch.index_select(tokens, dim=1, index=beams_buf.squeeze())[:, :, :step+1]
                elif step > 0:
                    tokens[:, :, :step+1] = torch.index_select(tokens.reshape(-1, tokens.shape[2]), dim=0, index=beams_buf.reshape(-1)).reshape(current_batch_size, -1, tokens.shape[2])[:, :, :step+1]
                # Feed the full prefix back to the decoder at every step (no incremental cache).
                tgt_in_batch = tokens[:, :, :step+2].view(current_batch_size * beam_width, -1)
                tgt_in_batch = tgt_in_batch.long().to(emb.device)

                step = step + 1
            # Write out only the top-scoring beam (index 0 after topk) per instance.
            for inst in range(current_batch_size):
                res = "".join([args.tokenizer.inv_vocab[token_id.item()] for token_id in tokens[inst, 0, :]])
                res = res.split(SEP_TOKEN)[1].split(CLS_TOKEN)[0]  # res.split(CLS_TOKEN)[1].split(SEP_TOKEN)[0] for tencentpretrain model
                res = res.replace("▁", " ")

                f.write(res)
                f.write("\n")


if __name__ == "__main__":
    main()