conex / espnet2 /tts /abs_tts.py
tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
782 Bytes
from abc import ABC
from abc import abstractmethod
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
class AbsTTS(torch.nn.Module, ABC):
    """Abstract base class defining the interface of a TTS model.

    Subclasses must override both :meth:`forward` and :meth:`inference`.
    Because both are declared abstract, this class itself cannot be
    instantiated.
    """

    @abstractmethod
    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        spcs: Optional[torch.Tensor] = None,
        spcs_lengths: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Run the model on paired text and speech inputs.

        Args:
            text: Text input tensor.
            text_lengths: Lengths associated with ``text``.
            speech: Speech input tensor.
            speech_lengths: Lengths associated with ``speech``.
            spembs: Optional extra input; may be omitted (``None``).
                NOTE(review): presumably speaker embeddings - confirm.
            spcs: Optional extra input; may be omitted (``None``).
                NOTE(review): meaning is not visible here - confirm
                against the concrete models.
            spcs_lengths: Optional lengths associated with ``spcs``.

        Returns:
            A 3-tuple ``(tensor, dict of str -> tensor, tensor)``.
            NOTE(review): presumably (loss, stats, batch weight) -
            confirm against the caller.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @abstractmethod
    def inference(
        self,
        text: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate output from text alone.

        Args:
            text: Text input tensor.
            spembs: Optional extra input; may be omitted (``None``).
                NOTE(review): presumably a speaker embedding - confirm.
            **kwargs: Additional model-specific options; which keys are
                accepted is decided by each subclass.

        Returns:
            A 3-tuple of tensors.
            NOTE(review): the meaning of each element is not visible
            here - confirm against the concrete models.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError