Update models/resunet.py
Browse files- models/resunet.py +60 -0
models/resunet.py
CHANGED
@@ -652,4 +652,64 @@ class ResUNet30(nn.Module):
|
|
652 |
|
653 |
return output_dict
|
654 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
655 |
|
|
|
|
652 |
|
653 |
return output_dict
|
654 |
|
655 |
+
|
656 |
+
@torch.no_grad()
def chunk_inference(self, input_dict):
    """Run separation over a long mixture in overlapping chunks.

    The mixture is processed window by window. Each window consists of a
    left context (NL samples), a centre part (NC samples) and a right
    context (NR samples). The window advances by NC samples per step, and
    only the part of each chunk output that is not context is copied into
    the result, except at the very start of the signal where the left
    context has no earlier chunk to come from and is therefore kept.

    Args:
        input_dict: mapping with keys ``'mixture'`` and ``'condition'``.
            ``'mixture'`` is indexed as ``mixtures[:, :, a:b]`` and its
            length is read from ``shape[2]``.
            NOTE(review): the output buffer is ``(1, L)`` and the chunk
            output is indexed as 2-D after ``squeeze(0)``, so this looks
            like it expects a batch of one mono signal - confirm with
            callers.

    Returns:
        ``numpy.ndarray`` of shape ``(1, L)`` holding the stitched output
        waveform, where ``L = input_dict['mixture'].shape[2]``.
    """
    # Context / hop sizes: 1 s left context, 3 s centre, 1 s right context,
    # converted to samples with the model's sampling rate.
    rate = self.sampling_rate
    NL = int(1.0 * rate)
    NC = int(3.0 * rate)
    NR = int(1.0 * rate)
    WINDOW = NL + NC + NR

    mixtures = input_dict['mixture']
    conditions = input_dict['condition']

    # The conditioning is the same for every chunk, so compute it once.
    film_dict = self.film(
        conditions=conditions,
    )

    L = mixtures.shape[2]
    out_np = np.zeros([1, L])

    def _separate(start):
        # Forward one window starting at `start`; the slice is shorter than
        # WINDOW when it runs past the end of the signal.
        chunk_in = mixtures[:, :, start:start + WINDOW]
        chunk_out = self.base(
            mixtures=chunk_in,
            film_dict=film_dict,
        )['waveform']
        return chunk_out.squeeze(0).detach().cpu().numpy()

    current_idx = 0

    # Full windows that lie strictly inside the signal.
    while current_idx + WINDOW < L:
        chunk_out_np = _separate(current_idx)

        # The first chunk keeps its left context; later chunks drop it
        # because the previous chunk already produced those samples.
        skip = 0 if current_idx == 0 else NL
        # Explicit end index instead of `:-NR`, which would select nothing
        # when NR == 0.
        keep_end = WINDOW - NR
        out_np[:, current_idx + skip:current_idx + keep_end] = \
            chunk_out_np[:, skip:keep_end]

        current_idx += NC

    # Last (possibly shorter) window: keep everything up to the end.
    if current_idx < L:
        chunk_out_np = _separate(current_idx)

        # If the whole signal fits in a single window there is no previous
        # chunk, so the leading NL samples must be kept as well; otherwise
        # they would stay zero in the output.
        skip = 0 if current_idx == 0 else NL
        seg_len = chunk_out_np.shape[1]
        out_np[:, current_idx + skip:current_idx + seg_len] = \
            chunk_out_np[:, skip:]

    return out_np
|