from typing import List, Tuple

import torch

from modules.wenet_extractor.utils.common import log_add


class Sequence:
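    """One beam hypothesis: the partial token sequence, its accumulated
    log probability, and the predictor state cache used to extend it."""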
    __slots__ = ("hyp", "score", "cache")

    def __init__(
        self,
        hyp: List[torch.Tensor],
        score,
        cache: List[torch.Tensor],
    ):
        self.hyp = hyp
        self.score = score
        self.cache = cache


class PrefixBeamSearch:
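    """Transducer prefix beam search with CTC score fusion (batch size 1)."""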
    def __init__(self, encoder, predictor, joint, ctc, blank):
        self.encoder = encoder
        self.predictor = predictor
        self.joint = joint
        self.ctc = ctc
        self.blank = blank

    def forward_decoder_one_step(
        self, encoder_x: torch.Tensor, pre_t: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
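        """Run one prediction-network step plus the joint network.

        `encoder_x` is a single encoder frame; `pre_t` holds the last emitted
        token of every hypothesis in the beam. Returns per-token
        log-probabilities and the updated predictor cache.
        """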
        padding = torch.zeros(pre_t.size(0), 1, device=encoder_x.device)
        pre_t, new_cache = self.predictor.forward_step(
            pre_t.unsqueeze(-1), padding, cache
        )
        x = self.joint(encoder_x, pre_t)
        x = x.log_softmax(dim=-1)
        return x, new_cache

    def prefix_beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        beam_size: int = 5,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        ctc_weight: float = 0.3,
        transducer_weight: float = 0.7,
    ):
        """Prefix beam search for the transducer, with CTC score fusion.

        See also wenet.transducer.transducer.beam_search.

        Returns the final beam (best hypothesis first) and the encoder output.
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]
        assert batch_size == 1
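
        # Encode the input utterance; chunking behaviour follows
        # decoding_chunk_size and num_decoding_left_chunks.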
        encoder_out, _ = self.encoder(
            speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks
        )
        maxlen = encoder_out.size(1)
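
        # Frame-level CTC log-probabilities, used below for score fusion.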
        ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0)
        beam_init: List[Sequence] = []
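
        # Start the beam with a single hypothesis containing only the blank token.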
        cache = self.predictor.init_state(1, method="zero", device=device)
        beam_init.append(Sequence(hyp=[self.blank], score=0.0, cache=cache))
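
        # Decode the encoder output frame by frame.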
        for i in range(maxlen):
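            # Gather the last token, predictor cache, and score of every
            # hypothesis currently in the beam.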
            input_hyp = [s.hyp[-1] for s in beam_init]
            input_hyp_tensor = torch.tensor(input_hyp, dtype=torch.int, device=device)

            cache_batch = self.predictor.cache_to_batch([s.cache for s in beam_init])

            scores = torch.tensor([s.score for s in beam_init]).to(device)
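
            # Run one prediction-network + joint step for the current encoder frame.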
            logp, new_cache = self.forward_decoder_one_step(
                encoder_out[:, i, :].unsqueeze(1),
                input_hyp_tensor,
                cache_batch,
            )
            logp = logp.squeeze(1).squeeze(1)
            new_cache = self.predictor.batch_to_cache(new_cache)
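
            # Fuse transducer and CTC scores in the probability domain:
            # log(transducer_weight * P_rnnt + ctc_weight * P_ctc).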
            logp = torch.log(
                torch.add(
                    transducer_weight * torch.exp(logp),
                    ctc_weight * torch.exp(ctc_probs[i].unsqueeze(0)),
                )
            )
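
            # Keep the top-k token expansions per hypothesis and accumulate
            # the prefix scores.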
            top_k_logp, top_k_index = logp.topk(beam_size)
            scores = torch.add(scores.unsqueeze(1), top_k_logp)
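
            # Expand the beam: a blank keeps the prefix and its cache unchanged;
            # any other token extends the prefix and takes the updated cache.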
            beam_A = []
            for j in range(len(beam_init)):
                base_seq = beam_init[j]
                for t in range(beam_size):
                    if top_k_index[j, t] == self.blank:
                        new_seq = Sequence(
                            hyp=base_seq.hyp.copy(),
                            score=scores[j, t].item(),
                            cache=base_seq.cache,
                        )
                        beam_A.append(new_seq)
                    else:
                        hyp_new = base_seq.hyp.copy()
                        hyp_new.append(top_k_index[j, t].item())
                        new_seq = Sequence(
                            hyp=hyp_new, score=scores[j, t].item(), cache=new_cache[j]
                        )
                        beam_A.append(new_seq)
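
            # Merge expansions that produced the same prefix by log-adding
            # their scores.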
            fusion_A = [beam_A[0]]
            for j in range(1, len(beam_A)):
                s1 = beam_A[j]
                if_do_append = True
                for t in range(len(fusion_A)):
                    if s1.hyp == fusion_A[t].hyp:
                        fusion_A[t].score = log_add([fusion_A[t].score, s1.score])
                        if_do_append = False
                        break
                if if_do_append:
                    fusion_A.append(s1)
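
            # Prune back to beam_size, best score first.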
            fusion_A.sort(key=lambda x: x.score, reverse=True)
            beam_init = fusion_A[:beam_size]

        return beam_init, encoder_out