#! /usr/bin/env python3

# This script takes one or more directories (or files) of MIDI files and converts them to images.
# It was intended for the POP909 dataset.
# It will create a directory of images for each MIDI file, where each image is a frame of the MIDI file.

# Legacy note (chord markers, --chords): older versions of this script required running the chord
# extractor (extract_chords.py) on all the files first, and then writing a list of all the unique
# chord names to all_chords.txt, e.g.
# $ cat ~/datasets/jsb_chorales_midi/*/*_chords.txt | awk -F'\t' '{print $3}' | sort | uniq -c | awk '{print $2}' > all_chords.txt 
#   cat midis/*_chords.txt | awk -F'\t' '{print $3}' | sort | uniq -c | awk '{print $2}' > all_chords.txt
#   find midis -name '*_chords.txt' -exec cat {} + | awk -F'\t' '{print $3}' | sort | uniq -c | awk '{print $2}' > all_chords.txt
# The script now uses control_toys.chords.POSSIBLE_CHORDS as the chord list, so all_chords.txt is no longer read.

import os
import sys
from multiprocessing import Pool, cpu_count, set_start_method
from tqdm import tqdm
from control_toys.data import fast_scandir
from functools import partial
import argparse
from control_toys.pianoroll import midi_to_pr_img
from control_toys.chords import simplify_chord, POSSIBLE_CHORDS

def wrapper(args, midi_file, all_chords=None):
    """Render a single MIDI file with ``midi_to_pr_img`` using the parsed CLI options.

    The main block binds ``args`` and ``all_chords`` with ``functools.partial`` so the
    worker pool's ``imap`` only has to supply ``midi_file``.

    Args:
        args: the ``argparse.Namespace`` produced by this script's parser.
        midi_file: path of the MIDI file to convert.
        all_chords: chord-name list forwarded to ``midi_to_pr_img``; None when
            chord markers are disabled.

    Returns:
        Whatever ``midi_to_pr_img`` returns (the main block discards it).
    """
    render_opts = dict(
        show_chords=args.chords,
        all_chords=all_chords,
        chord_names=args.chord_names,
        filter_mp=args.filter_mp,
        add_onsets=args.onsets,
        # ``--silence`` means "keep the leading silence", so removal is its negation.
        remove_leading_silence=not args.silence,
    )
    return midi_to_pr_img(midi_file, args.output_dir, **render_opts)


if __name__ == '__main__':
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('-c','--chords', action='store_true', help='infer chords and add markers')
    p.add_argument('--chord-names', action='store_true', help='Add text for chord names')
    p.add_argument('--filter-mp', default=True, help='filter out non-piano, non-melody instruments')
    #p.add_argument('--onsets', default=True, type=bool, help='add onset markers')
    p.add_argument('--onsets', default=True, action=argparse.BooleanOptionalAction, help='Produce onset markers')   # either --onsets or --no-onsets, default is...? 
    p.add_argument('--silence', default=True, action=argparse.BooleanOptionalAction, help='Leave silence at start of song (True) or remove it (False)')   
    p.add_argument('--start-method', type=str, default='fork',
                   choices=['fork', 'forkserver', 'spawn'],
                   help='the multiprocessing start method')
    p.add_argument('--simplify', action='store_true', help='Simplify chord types, e.g. remove 13s')
    p.add_argument('--skip-versions', default=True, help='skip extra versions of the same song')
    p.add_argument("midi_dirs", nargs='+', help="directories containing MIDI files")
    p.add_argument("output_dir", help="output directory")
    args = p.parse_args()
    print("args = ",args)

    set_start_method(args.start_method)
    midi_dirs, output_dir = args.midi_dirs, args.output_dir

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if os.path.isdir(midi_dirs[0]):
        midi_files = []
        for mdir in midi_dirs:
            m_subdirs, mf = fast_scandir(mdir, ['mid', 'midi'])
            if mf != []: midi_files = midi_files + mf
    elif os.path.isfile(midi_dirs[0]):
        midi_files = midi_dirs


    if args.skip_versions: 
        midi_files = [f for f in midi_files if '/versions/' not in f]
    print("len(midi_files) = ",len(midi_files)) # just a check for debugging
               
    if args.chords: 
        # TODO: this is janky af but for now...
        # Get a list of all unique chords from a premade text file list of possible chords
        # to make the file in bash, assuming you've already run extract_chords.py: (leave sort alphabetical, don't sort in order of freq)
        # cat */*_chords.txt | awk -F'\t' '{print $3}' | sort | uniq -c | awk '{print $2}' > all_chords.txt
        #with open('all_chords.txt') as f:
        #    all_chords = f.read().splitlines()
        # use possible chords as all chords
        all_chords = POSSIBLE_CHORDS  # now we just generate these
        if args.simplify:
            all_chords = list(set([simplify_chord(c) for c in all_chords]))
        print("len(all_chords) = ",len(all_chords))   
        print("all_chords = ",all_chords) # just a check for debugging  
    else:
        all_chords = None

    process_one = partial(wrapper, args, all_chords=all_chords)
    num_cpus = cpu_count()
    with Pool(num_cpus) as p:
        list(tqdm(p.imap(process_one, midi_files), total=len(midi_files), desc='Processing MIDI files'))

    print("Finished")