File size: 1,020 Bytes
7694c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#import matplotlib
# matplotlib.use("Agg")
import matplotlib.pylab as plt


def get_alignment_figure(img):
    fig = plt.figure(figsize=(6, 4))
    plt.imshow(img, aspect='auto', origin='lower',
               interpolation='none')
    plt.xlabel('Spectrogram frame')
    plt.ylabel('Input token')
    plt.colorbar()
    plt.tight_layout()
    return fig


def get_spectrogram_figure(spec):
    fig = plt.figure(figsize=(12, 3))
    plt.imshow(spec, aspect='auto', origin='lower',
               interpolation='none')
    plt.xlabel('Frame')
    plt.ylabel('Channel')
    plt.colorbar()
    plt.tight_layout()
    return fig


def get_specs_figure(specs, xlabels):
    n = len(specs)
    fig, axes = plt.subplots(n, 1, figsize=(12, 3*n))

    for i, ax in enumerate(axes):
        im = ax.imshow(specs[i], aspect='auto', origin='lower',
                       interpolation='none')
        ax.set_xlabel(xlabels[i])
        ax.set_ylabel('Channel')
        plt.colorbar(im, ax=ax)

    plt.tight_layout()
    return fig