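# Model factory for the wenet_extractor modules: builds the encoder, decoder,
# CTC head, and optional Transducer/Paraformer components from a config dict.
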
import torch

from modules.wenet_extractor.transducer.joint import TransducerJoint
from modules.wenet_extractor.transducer.predictor import (
    ConvPredictor,
    EmbeddingPredictor,
    RNNPredictor,
)
from modules.wenet_extractor.transducer.transducer import Transducer
from modules.wenet_extractor.transformer.asr_model import ASRModel
from modules.wenet_extractor.transformer.cmvn import GlobalCMVN
from modules.wenet_extractor.transformer.ctc import CTC
from modules.wenet_extractor.transformer.decoder import (
    BiTransformerDecoder,
    TransformerDecoder,
)
from modules.wenet_extractor.transformer.encoder import (
    ConformerEncoder,
    TransformerEncoder,
)
from modules.wenet_extractor.squeezeformer.encoder import SqueezeformerEncoder
from modules.wenet_extractor.efficient_conformer.encoder import (
    EfficientConformerEncoder,
)
from modules.wenet_extractor.paraformer.paraformer import Paraformer
from modules.wenet_extractor.cif.predictor import Predictor
from modules.wenet_extractor.utils.cmvn import load_cmvn


def init_model(configs): |
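    """Instantiate an ASR model from a parsed config dict.

    The "encoder" and "decoder" keys select the architectures. A
    "predictor" key switches to a Transducer, a "paraformer" key to a
    Paraformer; otherwise a hybrid CTC/attention ASRModel is returned.
    """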
|
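    # Load global CMVN stats (if configured) and wrap them in a module that
    # the encoder applies to its input features.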
if configs["cmvn_file"] is not None: |
|
mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"]) |
|
global_cmvn = GlobalCMVN( |
|
torch.from_numpy(mean).float(), torch.from_numpy(istd).float() |
|
) |
|
else: |
|
global_cmvn = None |
|
|
|
    input_dim = configs["input_dim"]
    vocab_size = configs["output_dim"]

    encoder_type = configs.get("encoder", "conformer")
    decoder_type = configs.get("decoder", "bitransformer")

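    # Build the encoder; unrecognized types fall back to a plain
    # TransformerEncoder.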
    if encoder_type == "conformer":
        encoder = ConformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )
    elif encoder_type == "squeezeformer":
        encoder = SqueezeformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )
    elif encoder_type == "efficientConformer":
        encoder = EfficientConformerEncoder(
            input_dim,
            global_cmvn=global_cmvn,
            **configs["encoder_conf"],
            **configs["encoder_conf"].get("efficient_conf", {}),
        )
    else:
        encoder = TransformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )

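    # Build the decoder; the bitransformer default requires a right-to-left
    # branch (r_num_blocks > 0) and a reverse_weight in (0, 1).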
    if decoder_type == "transformer":
        decoder = TransformerDecoder(
            vocab_size, encoder.output_size(), **configs["decoder_conf"]
        )
    else:
        assert 0.0 < configs["model_conf"]["reverse_weight"] < 1.0
        assert configs["decoder_conf"]["r_num_blocks"] > 0
        decoder = BiTransformerDecoder(
            vocab_size, encoder.output_size(), **configs["decoder_conf"]
        )

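    # CTC head over the encoder output, shared by all model variants below.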
    ctc = CTC(vocab_size, encoder.output_size())

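    # A "predictor" key selects a Transducer model: encoder + label predictor
    # + joint network, with the attention decoder and CTC kept as auxiliaries.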
if "predictor" in configs: |
|
predictor_type = configs.get("predictor", "rnn") |
|
if predictor_type == "rnn": |
|
predictor = RNNPredictor(vocab_size, **configs["predictor_conf"]) |
|
elif predictor_type == "embedding": |
|
predictor = EmbeddingPredictor(vocab_size, **configs["predictor_conf"]) |
|
configs["predictor_conf"]["output_size"] = configs["predictor_conf"][ |
|
"embed_size" |
|
] |
|
elif predictor_type == "conv": |
|
predictor = ConvPredictor(vocab_size, **configs["predictor_conf"]) |
|
configs["predictor_conf"]["output_size"] = configs["predictor_conf"][ |
|
"embed_size" |
|
] |
|
else: |
|
raise NotImplementedError("only rnn, embedding and conv type support now") |
|
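        # Wire the joint network input sizes to the encoder and predictor
        # output sizes recorded in the config.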
configs["joint_conf"]["enc_output_size"] = configs["encoder_conf"][ |
|
"output_size" |
|
] |
|
configs["joint_conf"]["pred_output_size"] = configs["predictor_conf"][ |
|
"output_size" |
|
] |
|
joint = TransducerJoint(vocab_size, **configs["joint_conf"]) |
|
        model = Transducer(
            vocab_size=vocab_size,
            blank=0,
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            **configs["model_conf"],
        )
|
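    # A "paraformer" key selects the non-autoregressive Paraformer with a
    # CIF predictor.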
elif "paraformer" in configs: |
|
predictor = Predictor(**configs["cif_predictor_conf"]) |
|
model = Paraformer( |
|
vocab_size=vocab_size, |
|
encoder=encoder, |
|
decoder=decoder, |
|
ctc=ctc, |
|
predictor=predictor, |
|
**configs["model_conf"], |
|
) |
|
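    # Otherwise build the standard hybrid CTC/attention model.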
    else:
        model = ASRModel(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            lfmmi_dir=configs.get("lfmmi_dir", ""),
            **configs["model_conf"],
        )
    return model
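
# Usage sketch (hypothetical file name and dimensions; the required keys are
# exactly the ones read above):
#
#     import yaml
#
#     with open("train.yaml") as f:
#         configs = yaml.safe_load(f)
#     configs["cmvn_file"] = "global_cmvn"  # or None to skip CMVN
#     configs["is_json_cmvn"] = True
#     configs["input_dim"] = 80     # e.g. 80-dim fbank features
#     configs["output_dim"] = 5000  # vocabulary size
#     model = init_model(configs)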