from itertools import chain
from typing import Any, Dict, List, NamedTuple, Tuple, Union

import torch

from modules.wenet_extractor.paraformer.utils import end_detect
from modules.wenet_extractor.paraformer.search.ctc import CTCPrefixScorer
from modules.wenet_extractor.paraformer.search.scorer_interface import (
    ScorerInterface,
    PartialScorerInterface,
)


class Hypothesis(NamedTuple):
    """Hypothesis data type."""

    yseq: torch.Tensor
    score: Union[float, torch.Tensor] = 0
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict."""
        return self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()

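
# A minimal sketch of how a Hypothesis round-trips through `asdict()`; the
# token ids below are illustrative only:
#
#     hyp = Hypothesis(yseq=torch.tensor([1, 42, 7]), score=-3.5, scores={"ctc": -4.0})
#     hyp.asdict()
#     # {'yseq': [1, 42, 7], 'score': -3.5, 'scores': {'ctc': -4.0}, 'states': {}}
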

class BeamSearchCIF(torch.nn.Module):
    """Beam search implementation."""

    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules,
                e.g., Decoder, CTCPrefixScorer, LM.
                A scorer is ignored if it is `None`.
            weights (dict[str, float]): Dict of weights for each scorer.
                A scorer is ignored if its weight is 0.
            beam_size (int): The number of hypotheses kept during search.
            vocab_size (int): The size of the vocabulary.
            sos (int): Start-of-sequence token id.
            eos (int): End-of-sequence token id.
            pre_beam_ratio (float): The beam size used in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`.
            pre_beam_score_key (str): Key of the scores used to perform the
                pre-beam search.

        """
        super().__init__()

        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()

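        # Partition the scorers: full scorers rescore the whole vocabulary at
        # every step, while partial scorers (e.g. a CTC prefix scorer) only
        # rescore a candidate subset.  nn.Module scorers are also registered
        # in a ModuleDict so that `self.to(device, dtype)` casts them along
        # with this module.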
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v

        self.sos = sos
        self.eos = eos
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
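        # Pre-beam pruning is only enabled when a score key is given, at least
        # one partial scorer is registered, and the pre-beam size is smaller
        # than the vocabulary; otherwise partial scorers see the full vocabulary.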
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )

    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            List[Hypothesis]: The initial hypotheses (a single hypothesis
                that contains only the `sos` token).

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.

        Args:
            xs (torch.Tensor): The prefix tokens
            x (int): The new token to append

        Returns:
            torch.Tensor: New tensor containing xs + [x] with xs.dtype and
                xs.device

        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))

    def score_full(
        self, hyp: Hypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states

    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each
                token. Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`

        """
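        # When the candidate set already covers the whole vocabulary, the full
        # and local top-k ids coincide.  Otherwise, every score outside `ids`
        # is masked to -inf so the top-k over the full vocabulary can only pick
        # pre-beam candidates; the second top-k returns the positions of those
        # candidates within `ids`, which is what the partial scorers'
        # `select_state` expects.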
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids

        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids

    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.

        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by
                `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and
                `self.part_scorers`.
                Its values are scalar tensors by the scorers.

        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_states`

        Returns:
            Dict[str, torch.Tensor]: The new state dict.
                Its keys are names of `self.full_scorers` and
                `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states

    def search(
        self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
    ) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)
            am_score (torch.Tensor): Acoustic scores of the current decoding
                step for every vocabulary entry, shape (n_vocab,)

        Returns:
            List[Hypothesis]: Best sorted hypotheses

        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=x.device)
        for hyp in running_hyps:
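            # Accumulate per-token scores: start from the shared acoustic
            # score for this output position, then add the weighted scores of
            # every full scorer.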
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            weighted_scores += am_score
            scores, states = self.score_full(hyp, x)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
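            # Optionally prune to the top `pre_beam_size` candidates before
            # running the (more expensive) partial scorers on them.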
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]

            weighted_scores += hyp.score
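            # Expand the hypothesis with each of the top beam_size tokens,
            # merging per-scorer scores and selecting the matching partial
            # scorer states.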
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )

        best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
            : min(len(best_hyps), self.beam_size)
        ]
        return best_hyps

    def forward(
        self,
        x: torch.Tensor,
        am_scores: torch.Tensor,
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            am_scores (torch.Tensor): Acoustic scores for every output
                position, shape (L, n_vocab); its length fixes the number of
                decoding steps
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses an end-detect function
                to automatically find maximum hypothesis lengths.
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            List[Hypothesis]: N-best decoding results

        """
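        # With a CIF-style predictor, the number of decoding steps is fixed by
        # the number of predicted token positions (the first dimension of
        # `am_scores`) rather than being derived from the input length.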
        maxlen = am_scores.shape[0]

        running_hyps = self.init_hyp(x)
        ended_hyps = []
        for i in range(maxlen):
            best = self.search(running_hyps, x, am_scores[i])
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                break

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)

        if len(nbest_hyps) == 0:
            return (
                []
                if minlenratio < 0.1
                else self.forward(
                    x, am_scores, maxlenratio, max(0.0, minlenratio - 0.1)
                )
            )

        return nbest_hyps

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam
                search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            List[Hypothesis]: The new running hypotheses.

        """
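        # At the final step, force every running hypothesis to terminate by
        # appending the end-of-sequence token.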
        if i == maxlen - 1:
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
            ]

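        # Hypotheses whose last token is eos are finalized: each scorer's
        # final score is added and the hypothesis is moved to `ended_hyps`;
        # all other hypotheses keep running.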
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps


def build_beam_search(model, args, device):
    scorers = {}
    if model.ctc is not None:
        ctc = CTCPrefixScorer(ctc=model.ctc, eos=model.eos)
        scorers.update(ctc=ctc)
    weights = dict(
        decoder=1.0 - args.ctc_weight,
        ctc=args.ctc_weight,
        length_bonus=args.penalty,
    )
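    # Only weights whose scorer is actually registered are used by
    # BeamSearchCIF; since only the CTC prefix scorer is constructed here,
    # the `decoder` and `length_bonus` entries are effectively ignored.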
    beam_search = BeamSearchCIF(
        beam_size=args.beam_size,
        weights=weights,
        scorers=scorers,
        sos=model.sos,
        eos=model.eos,
        vocab_size=model.vocab_size,
        pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
    )
    beam_search.to(device=device, dtype=torch.float32).eval()
    return beam_search
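

# A minimal usage sketch (the `model` and `args` attributes referenced above
# are assumptions about the caller, not defined in this module):
#
#     beam_search = build_beam_search(model, args, device=torch.device("cpu"))
#     # x: encoder output (T, D); am_scores: per-position scores (L, vocab_size)
#     nbest = beam_search(x, am_scores)
#     best_token_ids = nbest[0].yseq.tolist()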