# by HuggingFace
# ############################################################################
# Model: E2E attention-based ASR
# Encoder: CRDNN model
# Decoder: GRU + beamsearch + RNNLM
# Tokens: BPE with unigram
# Authors:  Ju-Chieh Chou, Mirco Ravanelli, Abdel Heba, Peter Plantinga 2020
# ############################################################################


# Feature parameters
# All three values are forwarded to `compute_features` (Fbank); `n_mels` is
# also the last entry of the `input_shape` given to `enc` and `encoder`.
n_mels: 40
n_fft: 400
sample_rate: 16000

# Model parameters

# -- referenced by both `enc` and `dec` --
dropout: 0.15

# -- referenced by the CRDNN encoder (`enc`) --
activation: !name:torch.nn.LeakyReLU
cnn_blocks: 2
cnn_channels: (128, 256)
cnn_kernelsize: (3, 3)
inter_layer_pooling_size: (2, 2)
time_pooling_size: 4
rnn_class: !name:speechbrain.nnet.RNN.LSTM
rnn_layers: 4
rnn_neurons: 1024
rnn_bidirectional: True
dnn_blocks: 2
# Also used as `enc_dim` of `dec` and as `input_size` of `ctc_lin`.
dnn_neurons: 512

# -- referenced by the embedding, decoder and output layers --
# `embedding_dim` of `emb` and `lm_model`, and `input_size` of `dec`.
emb_size: 128
# `hidden_size` of `dec` and `input_size` of `seq_lin`.
dec_neurons: 1024
# Size used by `emb`, `ctc_lin`, `seq_lin` and `lm_model`.
output_neurons: 1000  # index(blank/eos/bos) = 0
# Not referenced by any other entry visible in this file.
blank_index: 0

# Decoding parameters
# Every value in this section is forwarded, under the same name, to `decoder`.

# -- special token indices --
bos_index: 0
eos_index: 0
eos_threshold: 1.5

# -- search limits --
beam_size: 80
min_decode_ratio: 0.0
max_decode_ratio: 1.0

# -- attention shift constraint --
using_max_attn_shift: True
max_attn_shift: 240

# -- scoring --
coverage_penalty: 1.5
lm_weight: 0.50
temperature: 1.25
temperature_lm: 1.25

# Input normalization object with `norm_type: global`. The same entry is
# referenced by `encoder`, by `modules` and by the pretrainer's `loadables`.
normalizer: !new:speechbrain.processing.features.InputNormalization
    norm_type: global

# Fbank feature extractor, configured entirely from the feature parameters
# defined at the top of the file. Referenced by `encoder`.
compute_features: !new:speechbrain.lobes.features.Fbank
    n_mels: !ref <n_mels>
    n_fft: !ref <n_fft>
    sample_rate: !ref <sample_rate>

# CRDNN encoder. Referenced by `asr_model` and, as the `model` entry, by
# `encoder`. Values without a `!ref` are set inline rather than through a
# top-level parameter.
enc: !new:speechbrain.lobes.models.CRDNN.CRDNN
    input_shape: [null, null, !ref <n_mels>]
    activation: !ref <activation>
    dropout: !ref <dropout>
    # CNN settings
    cnn_blocks: !ref <cnn_blocks>
    cnn_channels: !ref <cnn_channels>
    cnn_kernelsize: !ref <cnn_kernelsize>
    # Pooling settings
    inter_layer_pooling_size: !ref <inter_layer_pooling_size>
    using_2d_pooling: False
    time_pooling: True
    time_pooling_size: !ref <time_pooling_size>
    # RNN settings
    rnn_class: !ref <rnn_class>
    rnn_layers: !ref <rnn_layers>
    rnn_neurons: !ref <rnn_neurons>
    rnn_bidirectional: !ref <rnn_bidirectional>
    rnn_re_init: True
    # DNN settings
    dnn_blocks: !ref <dnn_blocks>
    dnn_neurons: !ref <dnn_neurons>

# Embedding table with `output_neurons` entries of size `emb_size`.
# Referenced by `asr_model` and, as `embedding`, by `decoder`.
emb: !new:speechbrain.nnet.embedding.Embedding
    embedding_dim: !ref <emb_size>
    num_embeddings: !ref <output_neurons>

# Attentional RNN decoder. Referenced by `asr_model` and, as `decoder`, by the
# beam-search `decoder` entry. Values without a `!ref` are set inline.
dec: !new:speechbrain.nnet.RNN.AttentionalRNNDecoder
    # Sizes tied to the encoder output and the embedding
    enc_dim: !ref <dnn_neurons>
    input_size: !ref <emb_size>
    # Recurrent part
    rnn_type: gru
    hidden_size: !ref <dec_neurons>
    num_layers: 1
    re_init: True
    dropout: !ref <dropout>
    # Attention part
    attn_type: location
    attn_dim: 1024
    scaling: 1.0
    channels: 10
    kernel_size: 100

# Linear layer from `dnn_neurons` inputs to `output_neurons` outputs.
# Included in `asr_model`; not referenced by `encoder` or `decoder` in this
# file.
ctc_lin: !new:speechbrain.nnet.linear.Linear
    n_neurons: !ref <output_neurons>
    input_size: !ref <dnn_neurons>

# Linear layer from `dec_neurons` inputs to `output_neurons` outputs.
# Included in `asr_model` and passed to `decoder` as `linear`.
seq_lin: !new:speechbrain.nnet.linear.Linear
    n_neurons: !ref <output_neurons>
    input_size: !ref <dec_neurons>

# Softmax activation created with `apply_log: True`. No other entry visible in
# this file references it.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# RNN language model. Passed to `decoder` as `language_model`, listed in
# `modules`, and registered with the pretrainer under the key `lm`.
# Only `output_neurons` and `embedding_dim` follow the top-level parameters;
# the remaining values (including the activation) are set inline here.
lm_model: !new:speechbrain.lobes.models.RNNLM.RNNLM
    output_neurons: !ref <output_neurons>
    embedding_dim: !ref <emb_size>
    return_hidden: True  # For inference
    activation: !name:torch.nn.LeakyReLU
    dropout: 0.0
    # RNN settings
    rnn_layers: 2
    rnn_neurons: 2048
    # DNN settings
    dnn_blocks: 1
    dnn_neurons: 512

# SentencePiece processor created without arguments. The pretrainer entry
# `tokenizer` pairs it with the `1000_unigram.model` file.
tokenizer: !new:sentencepiece.SentencePieceProcessor

# All acoustic-model components gathered into one ModuleList, which the
# pretrainer registers under the key `asr` (file `model.ckpt`).
# NOTE(review): the element order presumably has to match the order used when
# the checkpoint was saved — confirm before reordering.
asr_model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <emb>, !ref <dec>, !ref <ctc_lin>, !ref <seq_lin>]

# We compose the inference (encoder) pipeline.
# It chains three entries defined above: feature extraction
# (`compute_features`), normalization (`normalizer`) and the CRDNN (`enc`).
# NOTE(review): the key order below presumably defines the order in which the
# stages run — confirm against LengthsCapableSequential before reordering.
encoder: !new:speechbrain.nnet.containers.LengthsCapableSequential
    input_shape: [null, null, !ref <n_mels>]
    compute_features: !ref <compute_features>
    normalize: !ref <normalizer>
    model: !ref <enc>

# Beam-search decoder with a language model. It is built from the `emb`,
# `dec`, `seq_lin` and `lm_model` entries; every other argument comes from the
# "Decoding parameters" section under the same name.
decoder: !new:speechbrain.decoders.S2SRNNBeamSearchLM
    # Sub-modules
    embedding: !ref <emb>
    decoder: !ref <dec>
    linear: !ref <seq_lin>
    language_model: !ref <lm_model>
    # Special token indices
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    eos_threshold: !ref <eos_threshold>
    # Search limits
    beam_size: !ref <beam_size>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    # Attention shift constraint
    using_max_attn_shift: !ref <using_max_attn_shift>
    max_attn_shift: !ref <max_attn_shift>
    # Scoring
    coverage_penalty: !ref <coverage_penalty>
    lm_weight: !ref <lm_weight>
    temperature: !ref <temperature>
    temperature_lm: !ref <temperature_lm>


# Named objects exposed under the `modules` key. `normalizer` and `lm_model`
# are listed on their own although `encoder` and `decoder` already reference
# them.
modules:
    normalizer: !ref <normalizer>
    encoder: !ref <encoder>
    decoder: !ref <decoder>
    lm_model: !ref <lm_model>

# Directory prefix for every file listed under `paths` below.
# NOTE(review): absolute path that looks environment-specific — confirm it
# exists wherever this file is loaded, or override it.
pretrained_path_local: /content/pretrained_models

# Pairs each loadable object with the file it is associated with; the keys of
# `loadables` and `paths` match one-to-one.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        normalizer: !ref <normalizer>
        asr: !ref <asr_model>
        lm: !ref <lm_model>
        tokenizer: !ref <tokenizer>
    # Listed in the same order as `loadables`.
    paths:
        normalizer: !ref <pretrained_path_local>/normalizer.ckpt
        asr: !ref <pretrained_path_local>/model.ckpt
        lm: !ref <pretrained_path_local>/lm.ckpt
        tokenizer: !ref <pretrained_path_local>/1000_unigram.model