|
import argparse |
|
|
|
import hparams |
|
import utils |
|
import multiresunet_model |
|
|
|
import tensorflow as tf |
|
import numpy as np |
|
|
|
if __name__ == '__main__': |
|
args = argparse.ArgumentParser() |
|
|
|
args.add_argument('Path',metavar='path',type=str,help='Path to DSD100 pickled spectrograms. See preprocess_data.py for more details') |
|
args.add_argument('Source',metavar='source',type=str,help='Desired source to separate') |
|
args.add_argument('Spectrum',metavar='spectrum',type=str,help='Low (lf) or High (hf) frequencies training') |
|
args.add_argument('Outpath',metavar='model_out_path',type=str,help='Path to save the model to') |
|
|
|
|
|
args = args.parse_args() |
|
path = args.Path |
|
source = args.Source |
|
spectrum = args.Spectrum |
|
output_path = args.Outpath |
|
|
|
|
|
x = np.load(path + 'mixture_' + spectrum + '.npy') |
|
y = np.load(path + source + '_' + spectrum + '.npy') |
|
|
|
|
|
model = multiresunet_model.Steminator((hparams.frequency_bins,hparams.chunk_size,hparams.n_channels)) |
|
optimizer = tf.keras.optimizers.Adam(lr = hparams.learning_rate) |
|
model.compile(optimizer, loss='mean_absolute_error') |
|
|
|
|
|
model.fit(x,y,epochs = hparams.epochs, batch_size = hparams.batch_size) |
|
|
|
|
|
model.save(output_path + source + '_' + spectrum + '.h5') |
|
|
|
|
|
|
|
|
|
|