""" Visualisation utils. """ import chess import chess.svg import matplotlib import matplotlib.pyplot as plt COLOR_MAP = matplotlib.colormaps["RdYlBu_r"].resampled(1000) ALPHA = 1.0 def render_heatmap( board, heatmap, square=None, vmin=None, vmax=None, arrows=None, normalise="none", ): """ Render a heatmap on the board. """ if normalise == "abs": a_max = heatmap.abs().max() if a_max != 0: heatmap = heatmap / a_max vmin = -1 vmax = 1 if vmin is None: vmin = heatmap.min() if vmax is None: vmax = heatmap.max() norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=False) color_dict = {} for square_index in range(64): color = COLOR_MAP(norm(heatmap[square_index])) color = (*color[:3], ALPHA) color_dict[square_index] = matplotlib.colors.to_hex(color, keep_alpha=True) fig = plt.figure(figsize=(6, 0.6)) ax = plt.gca() ax.axis("off") fig.colorbar( matplotlib.cm.ScalarMappable(norm=norm, cmap=COLOR_MAP), ax=ax, orientation="horizontal", fraction=1.0, ) if square is not None: try: check = chess.parse_square(square) except ValueError: check = None else: check = None if arrows is None: arrows = [] plt.close() return ( chess.svg.board( board, check=check, fill=color_dict, size=350, arrows=arrows, ), fig, )