|
from typing import List |
|
|
|
import torch |
|
|
|
|
|
def basic_greedy_search(
    model: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    n_steps: int = 64,
) -> List[List[int]]:
    """Greedy (best-path) transducer decoding of a single utterance.

    At every encoder frame the joint network is evaluated against the
    current predictor output and the arg-max token is taken. A non-blank
    token is appended to the hypothesis, fed back into the predictor, and
    decoding stays on the same frame. A blank token advances to the next
    frame. At most ``n_steps`` non-blank tokens are emitted per frame; after
    that the frame is advanced regardless.

    Interface expected from ``model`` (as used below):
      * ``model.blank`` -- the blank token id (also used as start symbol);
      * ``model.predictor.init_state(1, method="zero", device=...)`` -- the
        initial predictor cache for batch size 1;
      * ``model.predictor.forward_step(inp, padding, cache)`` -- returns an
        indexable whose ``[0]`` is the predictor output and ``[1]`` the
        updated cache;
      * ``model.joint(enc_step, pred_out)`` -- returns logits whose last
        dimension is the vocabulary.

    Args:
        model: Transducer model exposing the attributes listed above.
        encoder_out: Encoder output, indexed as ``[:, t : t + 1, :]``;
            presumably shaped (1, T, D) -- TODO confirm with callers.
        encoder_out_lens: Number of valid frames to decode. Must hold a
            single value (0-dim or one-element tensor, or an int).
        n_steps: Maximum number of non-blank emissions per frame.

    Returns:
        A one-element list containing the decoded token ids.
    """
    device = encoder_out.device
    blank = model.blank
    # Convert once so the loop condition is a plain int comparison instead
    # of a tensor comparison (plus implicit bool()) on every step.
    num_frames = int(encoder_out_lens)

    # Dummy padding argument passed to every predictor.forward_step call.
    padding = torch.zeros(1, 1, device=device)
    # Start-of-sequence predictor input is the blank id. Build it on the
    # encoder's device: the original built it on CPU, which mixes devices
    # with the padding/cache on the first forward_step when running on an
    # accelerator.
    pred_input_step = torch.tensor([blank], device=device).reshape(1, 1)
    cache = model.predictor.init_state(1, method="zero", device=device)
    new_cache: List[torch.Tensor] = []
    pred_out_step = None

    hyps: List[int] = []
    t = 0
    # True when the last emission was non-blank, i.e. the predictor must be
    # re-run on the newly emitted token before the next joint evaluation.
    prev_out_nblk = True
    per_frame_noblk = 0
    while t < num_frames:
        encoder_out_step = encoder_out[:, t : t + 1, :]
        if prev_out_nblk:
            step_outs = model.predictor.forward_step(
                pred_input_step, padding, cache
            )
            pred_out_step, new_cache = step_outs[0], step_outs[1]

        joint_out_step = model.joint(encoder_out_step, pred_out_step)
        joint_out_probs = joint_out_step.log_softmax(dim=-1)
        joint_out_max = joint_out_probs.argmax(dim=-1).squeeze()
        # Single host conversion per step; all comparisons below are on ints.
        token = joint_out_max.item()
        is_blank = token == blank

        if not is_blank:
            hyps.append(token)
            prev_out_nblk = True
            per_frame_noblk += 1
            pred_input_step = joint_out_max.reshape(1, 1)
            # Commit the predictor state only once a real token was emitted.
            cache = new_cache
        else:
            # Predictor output is still valid for the next frame; skip
            # re-running it.
            prev_out_nblk = False

        # Advance on blank, or when this frame hit its emission cap (which
        # guards against looping forever on a single frame).
        if is_blank or per_frame_noblk >= n_steps:
            t += 1
            per_frame_noblk = 0

    return [hyps]
|
|