# wasm-ar-tts / wasq / utils / plotting.py
# Author: wasmdashai; commit 7694c84 ("first commit"); 1.02 kB
#import matplotlib
# matplotlib.use("Agg")
import matplotlib.pylab as plt
def get_alignment_figure(img):
    """Draw a 2-D alignment matrix as a heat map with a colorbar.

    Args:
        img: 2-D array-like passed to ``imshow``. Its rows run along the
            y-axis (labelled 'Input token') and its columns along the
            x-axis (labelled 'Spectrogram frame'). Presumably this is a
            (tokens, frames) attention alignment; confirm against callers.

    Returns:
        A new 6x4-inch ``matplotlib`` figure. The caller owns it and is
        responsible for closing it.
    """
    fig = plt.figure(figsize=(6, 4))
    axis = fig.add_subplot(1, 1, 1)
    # origin='lower' puts row 0 at the bottom; 'none' disables smoothing
    # so every matrix cell stays a crisp block.
    heatmap = axis.imshow(
        img,
        aspect='auto',
        origin='lower',
        interpolation='none',
    )
    axis.set_xlabel('Spectrogram frame')
    axis.set_ylabel('Input token')
    fig.colorbar(heatmap, ax=axis)
    fig.tight_layout()
    return fig
def get_spectrogram_figure(spec):
    """Draw a single spectrogram-like 2-D array with a colorbar.

    Args:
        spec: 2-D array-like passed to ``imshow``. Its rows run along the
            y-axis (labelled 'Channel') and its columns along the x-axis
            (labelled 'Frame'). Presumably this is (channels, frames),
            e.g. a mel spectrogram; confirm against callers.

    Returns:
        A new 12x3-inch ``matplotlib`` figure. The caller owns it and is
        responsible for closing it.
    """
    fig = plt.figure(figsize=(12, 3))
    axis = fig.add_subplot(1, 1, 1)
    # Low channels at the bottom; no interpolation between bins.
    picture = axis.imshow(
        spec,
        aspect='auto',
        origin='lower',
        interpolation='none',
    )
    axis.set_xlabel('Frame')
    axis.set_ylabel('Channel')
    fig.colorbar(picture, ax=axis)
    fig.tight_layout()
    return fig
def get_specs_figure(specs, xlabels):
    """Draw several spectrogram-like 2-D arrays stacked vertically.

    Args:
        specs: Non-empty sequence of 2-D array-likes, one subplot row each.
            Each one is drawn with ``imshow`` using ``origin='lower'``, with
            rows on the y-axis (labelled 'Channel'). Presumably each is
            (channels, frames); confirm against callers.
        xlabels: Sequence of x-axis labels with at least one entry per item
            in ``specs``. Extra labels are ignored.

    Returns:
        A new ``matplotlib`` figure of size 12 x (3 * len(specs)) inches,
        with one axes and one colorbar per spectrogram. The caller owns
        the figure and is responsible for closing it.

    Raises:
        ValueError: If ``specs`` is empty or ``xlabels`` has fewer entries
            than ``specs``.
    """
    n = len(specs)
    # Validate before creating the figure so a bad call doesn't leave a
    # half-built figure registered with pyplot.
    if n == 0:
        raise ValueError("specs must contain at least one spectrogram")
    if len(xlabels) < n:
        raise ValueError(
            f"expected at least {n} xlabels (one per spectrogram), "
            f"got {len(xlabels)}"
        )
    # squeeze=False always returns a 2-D array of Axes. With the default
    # squeeze=True, n == 1 yields a bare Axes, which is not iterable.
    fig, axes = plt.subplots(n, 1, figsize=(12, 3 * n), squeeze=False)
    for ax, spec, xlabel in zip(axes[:, 0], specs, xlabels):
        im = ax.imshow(spec, aspect='auto', origin='lower',
                       interpolation='none')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Channel')
        fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig