# Generated 2021-03-09 from:
# /home/mila/s/subakany/speechbrain_new/recipes/WSJ0Mix/separation/train/hparams/sepformer-wham.yaml
# yamllint disable
# ################################
# Model: SepFormer for source separation
# https://arxiv.org/abs/2010.13154
#
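# Dataset : WHAM (data folder ending in wham_original); see the data params
# below for the changes needed to use WHAMR (wham_reverb) instead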
# ################################
# Basic parameters
# The seed has to be applied at the top of the yaml, before any object with
# parameters is created further down in this file.
seed: 1234
# Reference `seed` instead of repeating the literal, so that changing or
# overriding `seed` also changes the value handed to torch.manual_seed.
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Data params

# Path of the wsj0/si_tr_s/ folder, e.g. /yourpath/wsj0-processed/si_tr_s/
# Only needed if dynamic mixing is used.
wsj0_tr: /yourpath/wsj0-processed/si_tr_s/

# Data folder of the dataset. The path needs to end with:
#   wham_original  for the wham dataset
#   wham_reverb    for the whamr dataset
data_folder: /network/tmp1/subakany/wham_original

# Every output location is derived from `experiment_name` and `seed`, so that
# changing either of them moves all of the paths below consistently instead
# of leaving stale hard-coded copies behind.
experiment_name: sepformer-wham
output_folder: !ref results/<experiment_name>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save

# csv files for the train / valid / test data, stored in the save folder.
# For the whamr dataset the file names should start with whamr instead of
# whamorg.
train_data: !ref <save_folder>/whamorg_tr.csv
valid_data: !ref <save_folder>/whamorg_cv.csv
test_data: !ref <save_folder>/whamorg_tt.csv
skip_prep: false


# Experiment params
sample_rate: 8000
num_spks: 2 # set to 3 for wsj0-3mix
test_only: false
auto_mix_prec: false # set to true for mixed precision
save_audio: false # when true, the estimated sources are saved on disk
progressbar: true

# Training parameters
N_epochs: 200
batch_size: 1
lr: 0.00015
clip_grad_norm: 5
# upper limit for a loss value to be considered acceptable
loss_upper_lim: 999999

# When limit_training_signal_len is true, the training sequences are cut to
# training_signal_len; the length is only used if the limit is enabled.
training_signal_len: 32000000
limit_training_signal_len: false

# Set to true to create the mixtures dynamically at training time
dynamic_mixing: false

# Data augmentation switches
use_speedperturb: true
use_speedperturb_sameforeachsource: false
use_wavedrop: false

# Random shift augmentation and its lower / upper shift bounds
use_rand_shift: false
max_shift: 8000
min_shift: -8000

# Augmenter configured for speed perturbation only: the perturbation is always
# applied (probability 1.0) while frequency and chunk dropping are disabled.
speedperturb: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
  perturb_prob: 1.0
  drop_freq_prob: 0.0
  drop_chunk_prob: 0.0
  # follow the top-level sample_rate instead of a detached literal copy
  sample_rate: !ref <sample_rate>
  speeds: [95, 100, 105]

# Augmenter configured for dropping only: frequency and chunk dropping are
# always applied (probability 1.0) while speed perturbation is disabled.
wavedrop: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
  perturb_prob: 0.0
  drop_freq_prob: 1.0
  drop_chunk_prob: 1.0
  # follow the top-level sample_rate instead of a detached literal copy
  sample_rate: !ref <sample_rate>

# Loss thresholding: these two settings control the thresholding of the
# training loss (switch and threshold value).
threshold: -30
threshold_byloss: true

# Encoder parameters
kernel_size: 16
kernel_stride: 8
N_encoder_out: 256
out_channels: 256

# Dataloader options
dataloader_opts:
  # follow the top-level batch_size so the two values cannot drift apart
  batch_size: !ref <batch_size>
  num_workers: 3


# Specifying the network
# The encoder takes its sizes from the "Encoder parameters" section so that
# editing kernel_size / N_encoder_out there actually changes the model.
Encoder: &id003 !new:speechbrain.lobes.models.dual_path.Encoder
  kernel_size: !ref <kernel_size>
  out_channels: !ref <N_encoder_out>


# Transformer block used as the intra model of MaskNet.
SBtfintra: &id001 !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
  num_layers: 8
  # tied to out_channels (also given to MaskNet) instead of a literal copy;
  # NOTE(review): presumably d_model must equal MaskNet's out_channels - confirm
  d_model: !ref <out_channels>
  nhead: 8
  d_ffn: 1024
  dropout: 0
  use_positional_encoding: true
  norm_before: true

# Transformer block used as the inter model of MaskNet.
SBtfinter: &id002 !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
  num_layers: 8
  # tied to out_channels (also given to MaskNet) instead of a literal copy;
  # NOTE(review): presumably d_model must equal MaskNet's out_channels - confirm
  d_model: !ref <out_channels>
  nhead: 8
  d_ffn: 1024
  dropout: 0
  use_positional_encoding: true
  norm_before: true

# Dual-path masking network built from the intra / inter transformer blocks.
# num_spks and the channel sizes reference the top-level parameters: with
# literal copies here, following the "set to 3 for wsj0-3mix" note on the
# top-level num_spks would leave the model configured for 2 speakers.
MaskNet: &id005 !new:speechbrain.lobes.models.dual_path.Dual_Path_Model
  num_spks: !ref <num_spks>
  in_channels: !ref <N_encoder_out>
  out_channels: !ref <out_channels>
  num_layers: 2
  K: 250
  intra_model: *id001
  inter_model: *id002
  norm: ln
  linear_layer_after_inter_intra: false
  skip_around_intra: true

# The decoder mirrors the encoder settings: its input size, kernel size and
# stride reference the "Encoder parameters" section so that both ends of the
# model stay in sync when those values are edited.
Decoder: &id004 !new:speechbrain.lobes.models.dual_path.Decoder
  in_channels: !ref <N_encoder_out>
  out_channels: 1
  kernel_size: !ref <kernel_size>
  stride: !ref <kernel_stride>
  bias: false

# Adam optimizer; the learning rate references the top-level `lr` so that
# changing or overriding `lr` takes effect here.
optimizer: !name:torch.optim.Adam
  lr: !ref <lr>
  weight_decay: 0

# Training loss, given as a reference to get_si_snr_with_pitwrapper
# (no arguments are bound here).
loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper

# Learning-rate scheduler (ReduceLROnPlateau); it is also registered with the
# checkpointer below through the id007 anchor.
lr_scheduler: &id007 !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
  dont_halve_until_epoch: 65
  patience: 2
  factor: 0.5

# Epoch counter; the limit references N_epochs so that changing the number
# of epochs in the training parameters actually changes the training length.
epoch_counter: &id006 !new:speechbrain.utils.epoch_loop.EpochCounter
  limit: !ref <N_epochs>

# The three network components defined above, shared by alias with the
# checkpointer below (same anchors).
modules:
  encoder: *id003
  decoder: *id004
  masknet: *id005
# Checkpointer covering the three network components, the epoch counter and
# the lr scheduler. The directory references save_folder so that checkpoints
# always land next to the other saved files of the experiment.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
  checkpoints_dir: !ref <save_folder>
  recoverables:
    encoder: *id003
    decoder: *id004
    masknet: *id005
    counter: *id006
    lr_scheduler: *id007
# File logger; writes to the path configured as `train_log` above instead of
# a second hard-coded copy of that path.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
  save_file: !ref <train_log>

# Pretrainer whose loadables are the MaskNet, Encoder and Decoder entries
# defined above, referenced by their top-level keys.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
    masknet: !ref <MaskNet>
    encoder: !ref <Encoder>
    decoder: !ref <Decoder>