#!/usr/bin/env python3
"""
Copyright (c) 2020, Carleton University Biomedical Informatics Collaboratory

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import List
import os.path as path
import io

import PIL.Image
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib

from interfaces import Threshold

# Directory holding the symbol images that get drawn on the plot. It is
# resolved relative to this file: one level up, then "assets/symbols".
SYMBOLS_DIR = path.join(
    path.dirname(__file__),
    path.pardir,
    "assets",
    "symbols",
)

def figure_to_image(figure: matplotlib.figure.Figure, size, dpi=300) -> PIL.Image.Image:
    """Converts a matplotlib figure to a PIL Image.

    The figure is rendered as a PNG into an in-memory buffer, which is then
    opened with PIL. Note that the figure is resized in place so that the
    rendered output has the requested pixel dimensions.

    Parameters
    ----------
    figure : matplotlib.figure.Figure
    The figure to convert to an image.

    size: tuple
    Dimensions of the image in pixels, as (width, height).

    dpi: int
    Resolution (default: 300)

    Returns
    -------
    PIL.Image.Image
    The image of the plot in PIL format.
    """
    # Matplotlib sizes figures in inches, so convert the pixel dimensions
    # using the requested resolution.
    image_buffer = io.BytesIO()
    figure.set_size_inches((size[0] / dpi, size[1] / dpi))

    # Force a raster format PIL can decode; without `format`, the output
    # depends on the global `savefig.format` setting (which may be pdf/svg).
    figure.savefig(image_buffer, format="png", dpi=dpi)

    # Rewind so PIL reads the image from the start of the buffer
    image_buffer.seek(0)
    return PIL.Image.open(image_buffer)

def plot_audiogram(thresholds: List[Threshold]) -> matplotlib.figure.Figure:
    """Given a list of threshold dictionaries, plots the audiogram.

    Each threshold is drawn as an image symbol chosen by its ear
    ("left"/"right"), conduction ("air"/"bone") and masking (True/False)
    values, positioned at (frequency, threshold) on a log-frequency axis
    with an inverted y axis.

    Parameters
    ----------
    thresholds : List[Threshold]
    A list of dictionaries that implement the `Threshold` interface. The
    keys read here are "ear", "masking", "conduction", "frequency" and
    "threshold".

    Returns
    -------
    matplotlib.figure.Figure
    The `Figure` object corresponding to the audiogram.
    """
    # Plot setup
    fig = plt.figure()
    ax = fig.gca()

    # Setup axes
    ax.set_xscale("log")
    ax.set_xlim(125, 8000)  # Standard audios go from 125 to 8000 Hz
    ax.set_ylim(-20, 120)  # Most go from -10 to 120, but let's start at -20, in case
    ax.invert_yaxis()

    # Add the ticks
    ax.set_xticks([125, 250, 500, 1000, 2000, 4000, 8000])  # Octave frequencies
    # All multiples of 10; the stop is 121 so that the 120 tick at the axis
    # limit is included (range() excludes its stop value).
    ax.set_yticks(list(range(-20, 121, 10)))
    ax.get_xaxis().set_major_formatter(
        matplotlib.ticker.FormatStrFormatter("%.0f")
    )

    # Show the grid
    ax.grid()

    # Iterate through the different symbols
    for ear in ("left", "right"):
        for masking in (True, False):
            for conduction in ("air", "bone"):
                # Keep only the thresholds belonging to this individual
                # curve, in order of frequency
                curve = sorted(
                    (
                        threshold
                        for threshold in thresholds
                        if threshold["ear"] == ear
                        and threshold["masking"] == masking
                        and threshold["conduction"] == conduction
                    ),
                    key=lambda t: t["frequency"],
                )

                # Nothing to draw: skip, so the icon file is only read for
                # symbol kinds that actually appear in the data.
                if not curve:
                    continue

                # The icon is the same for every point of the curve, so it
                # is read from disk once rather than once per point.
                icon_name = f"{ear}_{conduction}_{'masked' if masking else 'unmasked'}.png"
                icon_img = plt.imread(path.join(SYMBOLS_DIR, icon_name))

                # Add a symbol on the plot for every threshold of the curve
                for point in curve:
                    icon = AnnotationBbox(
                        OffsetImage(icon_img, zoom=0.1),
                        (point["frequency"], point["threshold"]),
                        frameon=False,
                    )
                    ax.add_artist(icon)

    return fig