File size: 5,692 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from typing import List, Tuple
import torch
from modules.wenet_extractor.utils.common import log_add
class Sequence:
    """One hypothesis (beam unit) in prefix beam search.

    Attributes:
        hyp: emitted token ids; the search seeds it with the blank id.
        score: accumulated log-domain score of the hypothesis.
        cache: predictor state corresponding to ``hyp``.
    """

    # A tuple gives a deterministic slot order (a set literal does not).
    __slots__ = ("hyp", "score", "cache")

    def __init__(
        self,
        hyp: List[int],
        score: float,
        cache: List[torch.Tensor],
    ):
        self.hyp = hyp
        self.score = score
        self.cache = cache

    def __repr__(self) -> str:
        # cache is omitted on purpose: it holds (possibly large) state tensors.
        return f"{type(self).__name__}(hyp={self.hyp!r}, score={self.score!r})"
class PrefixBeamSearch:
    """Prefix beam search over a transducer, shallow-fused with CTC.

    The components are injected and used only through these calls:

    * ``encoder(speech, speech_lengths, decoding_chunk_size,
      num_decoding_left_chunks)`` returning ``(encoder_out, _)``;
    * ``predictor.init_state``, ``predictor.forward_step``,
      ``predictor.cache_to_batch`` and ``predictor.batch_to_cache``;
    * ``joint(encoder_x, predictor_out)`` returning joint-network logits;
    * ``ctc.log_softmax(encoder_out)`` returning CTC log-probabilities.

    ``blank`` is the blank token id.
    """

    def __init__(self, encoder, predictor, joint, ctc, blank):
        self.encoder = encoder
        self.predictor = predictor
        self.joint = joint
        self.ctc = ctc
        self.blank = blank

    def forward_decoder_one_step(
        self, encoder_x: torch.Tensor, pre_t: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run the predictor and the joint network for one encoder frame.

        Args:
            encoder_x: encoder output of the current frame (the search passes
                ``encoder_out[:, i, :].unsqueeze(1)``).
            pre_t: last token id of each hypothesis, one entry per hypothesis.
            cache: batched predictor state from ``predictor.cache_to_batch``.

        Returns:
            ``(logp, new_cache)``: log-softmax over the last (vocabulary)
            dimension, presumably shaped (beam, 1, 1, vocab) — TODO confirm
            against ``joint`` — and the updated batched predictor state.
        """
        # All-zero padding for the single step fed to the predictor
        # (presumably "no position is padded" — confirm against forward_step).
        padding = torch.zeros(pre_t.size(0), 1, device=encoder_x.device)
        pred_out, new_cache = self.predictor.forward_step(
            pre_t.unsqueeze(-1), padding, cache
        )
        logits = self.joint(encoder_x, pred_out)  # [beam, 1, 1, vocab]
        return logits.log_softmax(dim=-1), new_cache

    def prefix_beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        beam_size: int = 5,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        ctc_weight: float = 0.3,
        transducer_weight: float = 0.7,
    ):
        """Decode one utterance with transducer + CTC prefix beam search.

        Also see wenet.transducer.transducer.beam_search.

        The search is breadth-first over encoder frames. Every frame extends
        each hypothesis by exactly one unit (blank or a token), so a single
        frame never emits multiple units; the original authors note that
        experiments show this simplification has little impact.

        Args:
            speech: input features; the batch dimension must be 1.
            speech_lengths: lengths of ``speech``, one per batch entry.
            decoding_chunk_size: forwarded to the encoder; must not be 0.
            beam_size: top-k width per hypothesis and number of hypotheses
                kept after each frame.
            num_decoding_left_chunks: forwarded to the encoder.
            simulate_streaming: accepted for API compatibility; not used.
            ctc_weight: weight of the CTC probability in the fusion.
            transducer_weight: weight of the transducer probability.

        Returns:
            ``(beam, encoder_out)``: the final hypotheses sorted best-first
            and the encoder output. Each ``hyp`` starts with the blank id
            that seeded the search.
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]
        assert batch_size == 1, "prefix_beam_search decodes one utterance at a time"

        # 1. Encoder
        encoder_out, _ = self.encoder(
            speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks
        )  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        # Batch is 1, so squeeze(0) leaves one row of CTC log-probs per frame.
        ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0)

        # 2. Seed the beam with the blank token and a zeroed predictor state.
        cache = self.predictor.init_state(1, method="zero", device=device)
        beam: List[Sequence] = [Sequence(hyp=[self.blank], score=0.0, cache=cache)]

        # 3. Frame-synchronous search.
        for i in range(maxlen):
            # 3.1 Feed each hypothesis' last token and its predictor state.
            last_tokens = torch.tensor(
                [seq.hyp[-1] for seq in beam], dtype=torch.int, device=device
            )
            cache_batch = self.predictor.cache_to_batch([seq.cache for seq in beam])
            prev_scores = torch.tensor([seq.score for seq in beam], device=device)

            # 3.2 Transducer log-probs for every hypothesis.
            logp, new_cache = self.forward_decoder_one_step(
                encoder_out[:, i, :].unsqueeze(1), last_tokens, cache_batch
            )
            logp = logp.squeeze(1).squeeze(1)  # (N, vocab)
            new_cache = self.predictor.batch_to_cache(new_cache)

            # 3.3 Shallow fusion of transducer and CTC scores (an LM score could
            # be added here too). Mixed in probability space, then back to log.
            logp = torch.log(
                transducer_weight * torch.exp(logp)
                + ctc_weight * torch.exp(ctc_probs[i].unsqueeze(0))
            )

            # 3.4 First prune: top-k units per hypothesis. k is capped at the
            # vocabulary size so a small vocabulary does not make topk raise.
            k = min(beam_size, logp.size(-1))
            top_k_logp, top_k_index = logp.topk(k)  # (N, k)
            cand_scores = prev_scores.unsqueeze(1) + top_k_logp  # (N, k)

            # 3.5 / 3.6 Expand to N*k candidates, then merge equal prefixes.
            candidates = self._expand_beam(beam, top_k_index, cand_scores, new_cache)
            merged = self._merge_prefixes(candidates)

            # 4. Second prune: keep the best beam_size hypotheses (stable sort).
            merged.sort(key=lambda seq: seq.score, reverse=True)
            beam = merged[:beam_size]
        return beam, encoder_out

    def _expand_beam(
        self,
        beam: List["Sequence"],
        top_k_index: torch.Tensor,
        scores: torch.Tensor,
        new_cache: List[torch.Tensor],
    ) -> List["Sequence"]:
        """Extend every hypothesis with each of its top-k units.

        ``top_k_index[j]`` / ``scores[j]`` hold the candidate units and total
        scores for ``beam[j]``; ``new_cache[j]`` is its advanced predictor state.
        """
        # One host transfer each, instead of a device sync per `.item()` call.
        index_rows = top_k_index.tolist()
        score_rows = scores.tolist()
        expanded: List[Sequence] = []
        for j, base_seq in enumerate(beam):
            for token, score in zip(index_rows[j], score_rows[j]):
                if token == self.blank:
                    # Blank: hypothesis and predictor state stay, score updates.
                    expanded.append(
                        Sequence(
                            hyp=base_seq.hyp.copy(), score=score, cache=base_seq.cache
                        )
                    )
                else:
                    # Token: append it and adopt the advanced predictor state.
                    expanded.append(
                        Sequence(
                            hyp=base_seq.hyp + [token], score=score, cache=new_cache[j]
                        )
                    )
        return expanded

    @staticmethod
    def _merge_prefixes(candidates: List["Sequence"]) -> List["Sequence"]:
        """Merge candidates with identical ``hyp`` by log-adding their scores.

        The first occurrence of each prefix is kept (and its score updated in
        order of appearance), and first-occurrence order is preserved so the
        subsequent stable sort breaks ties the same way a linear scan would.
        """
        merged = {}  # tuple(hyp) -> first Sequence seen with that hyp
        for seq in candidates:
            key = tuple(seq.hyp)
            kept = merged.get(key)
            if kept is None:
                merged[key] = seq
            else:
                kept.score = log_add([kept.score, seq.score])
        return list(merged.values())
|