import numpy as np
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision

from utils.richtext_utils import seed_everything
from sklearn.cluster import KMeans, SpectralClustering

# SelfAttentionLayers = [
#     # 'down_blocks.0.attentions.0.transformer_blocks.0.attn1',
#     # 'down_blocks.0.attentions.1.transformer_blocks.0.attn1',
#     'down_blocks.1.attentions.0.transformer_blocks.0.attn1',
#     # 'down_blocks.1.attentions.1.transformer_blocks.0.attn1',
#     'down_blocks.2.attentions.0.transformer_blocks.0.attn1',
#     'down_blocks.2.attentions.1.transformer_blocks.0.attn1',
#     'mid_block.attentions.0.transformer_blocks.0.attn1',
#     'up_blocks.1.attentions.0.transformer_blocks.0.attn1',
#     'up_blocks.1.attentions.1.transformer_blocks.0.attn1',
#     'up_blocks.1.attentions.2.transformer_blocks.0.attn1',
#     # 'up_blocks.2.attentions.0.transformer_blocks.0.attn1',
#     'up_blocks.2.attentions.1.transformer_blocks.0.attn1',
#     # 'up_blocks.2.attentions.2.transformer_blocks.0.attn1',
#     # 'up_blocks.3.attentions.0.transformer_blocks.0.attn1',
#     # 'up_blocks.3.attentions.1.transformer_blocks.0.attn1',
#     # 'up_blocks.3.attentions.2.transformer_blocks.0.attn1',
# ]

# Module names of the self-attention (``attn1``) sub-layers of transformer
# block 0 in the selected UNet attention modules. Not referenced in this file;
# presumably consumed by the code that records attention maps — confirm.
# Commented-out prefixes are excluded; uncomment one to include it.
SelfAttentionLayers = [
    f'{attention_module}.transformer_blocks.0.attn1'
    for attention_module in (
        # 'down_blocks.0.attentions.0',
        # 'down_blocks.0.attentions.1',
        'down_blocks.1.attentions.0',
        # 'down_blocks.1.attentions.1',
        'down_blocks.2.attentions.0',
        'down_blocks.2.attentions.1',
        'mid_block.attentions.0',
        'up_blocks.1.attentions.0',
        'up_blocks.1.attentions.1',
        'up_blocks.1.attentions.2',
        # 'up_blocks.2.attentions.0',
        'up_blocks.2.attentions.1',
        # 'up_blocks.2.attentions.2',
        # 'up_blocks.3.attentions.0',
        # 'up_blocks.3.attentions.1',
        # 'up_blocks.3.attentions.2',
    )
]


# Module names of the cross-attention (``attn2``) sub-layers of transformer
# block 0 in the selected UNet attention modules. get_average_attention_maps
# ignores every layer not listed here. Commented-out prefixes are excluded.
CrossAttentionLayers = [
    f'{attention_module}.transformer_blocks.0.attn2'
    for attention_module in (
        # 'down_blocks.0.attentions.0',
        # 'down_blocks.0.attentions.1',
        'down_blocks.1.attentions.0',
        # 'down_blocks.1.attentions.1',
        'down_blocks.2.attentions.0',
        'down_blocks.2.attentions.1',
        'mid_block.attentions.0',
        'up_blocks.1.attentions.0',
        'up_blocks.1.attentions.1',
        'up_blocks.1.attentions.2',
        # 'up_blocks.2.attentions.0',
        'up_blocks.2.attentions.1',
        # 'up_blocks.2.attentions.2',
        # 'up_blocks.3.attentions.0',
        # 'up_blocks.3.attentions.1',
        # 'up_blocks.3.attentions.2',
    )
]

# CrossAttentionLayers = [
#     'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
#     'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
#     'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
#     'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
#     'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
#     'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
#     'mid_block.attentions.0.transformer_blocks.0.attn2',
#     'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
#     'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
#     'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
#     'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
#     'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
#     'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
#     'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
#     'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
#     'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
# ]

# CrossAttentionLayers_XL = [
#     'up_blocks.0.attentions.0.transformer_blocks.1.attn2',
#     'up_blocks.0.attentions.0.transformer_blocks.2.attn2',
#     'up_blocks.0.attentions.0.transformer_blocks.3.attn2',
#     'up_blocks.0.attentions.0.transformer_blocks.4.attn2',
#     'up_blocks.0.attentions.0.transformer_blocks.5.attn2',
#     'up_blocks.0.attentions.0.transformer_blocks.6.attn2',
#     'up_blocks.0.attentions.0.transformer_blocks.7.attn2',
# ]
# Cross-attention (``attn2``) module names for a UNet with several transformer
# blocks per attention module (presumably SDXL given the suffix — confirm).
# Each pair is (attention module prefix, transformer block indices to use).
# Not referenced in this file.
CrossAttentionLayers_XL = [
    f'{attention_module}.transformer_blocks.{block_idx}.attn2'
    for attention_module, block_indices in (
        ('down_blocks.2.attentions.1', (3, 4)),
        ('mid_block.attentions.0', (0, 1, 2, 3)),
        ('up_blocks.0.attentions.0', (1, 2, 3, 4, 5, 6, 7)),
        ('up_blocks.1.attentions.0', (0,)),
    )
    for block_idx in block_indices
]

def split_attention_maps_over_steps(attention_maps):
    r"""Regroup per-layer attention maps by step and by guidance branch.

    Args:
        attention_maps (dict): Maps each layer name to an indexable sequence of
            per-step entries. Each entry is sliced along its first axis:
            ``[:1]`` is kept as the unconditional map and ``[1:2]`` as the
            conditional map.

    Returns:
        tuple(dict, dict): ``(attention_maps_cond, attention_maps_uncond)``,
        each mapping ``step index -> {layer name -> sliced entry}``.
    """
    cond_by_step = {}
    uncond_by_step = {}

    for layer_name, per_step_maps in attention_maps.items():
        for step in range(len(per_step_maps)):
            step_map = per_step_maps[step]
            # Both result dicts gain the same step keys in the same order.
            uncond_by_step.setdefault(step, {})[layer_name] = step_map[:1]
            cond_by_step.setdefault(step, {})[layer_name] = step_map[1:2]

    return cond_by_step, uncond_by_step


def _save_token_heatmap_grid(sample_map, tokens_vis, save_path, nrows=2, ncols=7):
    r"""Save an ``nrows x ncols`` grid of heatmaps, one per token channel of one sample.

    Args:
        sample_map: Tensor indexed as ``[h, w, token]``. Channel ``tid`` is drawn
            in grid cell ``(tid // ncols, tid % ncols)``.
        tokens_vis (list[str]): Axis labels; ``tokens_vis[tid]`` labels channel
            ``tid``. Needs at least ``nrows * ncols`` entries.
        save_path (str): Output image path.
        nrows (int): Grid rows.
        ncols (int): Grid columns.
    """
    cmap = plt.get_cmap('YlOrRd')
    # squeeze=False keeps axs 2-D for any grid shape.
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    try:
        fig.set_figheight(8)
        fig.set_figwidth(28.5)
        for tid in range(nrows * ncols):
            ax = axs[tid // ncols, tid % ncols]
            token_map = sample_map[:, :, tid].numpy()
            # Every tile is scaled to its own min/max.
            sns.heatmap(
                token_map, annot=False, cbar=False, ax=ax, cmap=cmap,
                vmin=float(token_map.min()), vmax=float(token_map.max())
            )
            ax.set_xlabel(tokens_vis[tid])
        fig.tight_layout()
        fig.savefig(save_path, dpi=64)
    finally:
        plt.close(fig)


def save_attention_heatmaps(attention_maps, tokens_vis, save_dir, prefix):
    r"""Plot per-layer attention heatmaps and write one HTML index per sample.

    For every sample ``s`` this writes ``<save_dir>/sample_<s>/<prefix>.html``
    and, per layer, the image ``<save_dir>/sample_<s>/<prefix>/layer_<layer>.jpg``
    which the HTML page references.

    Args:
        attention_maps (dict): Layer name -> tensor indexed as
            ``[sample, h, w, token]``. The batch size is taken from the first layer.
        tokens_vis (list[str]): Token labels for the heatmap tiles.
        save_dir (str): Directory to save attention maps.
        prefix (str): Relative path (without extension) for the HTML files;
            its last component names the image sub-directory.

    Returns:
        list[str]: HTML paths relative to ``save_dir``, one per sample.
    """
    html_names = []
    layers = list(attention_maps.keys())
    if not layers:
        return html_names

    batch_size = attention_maps[layers[0]].shape[0]
    prefix_stem = prefix.split('/')[-1]

    html_list = []
    try:
        # Open one HTML page per sample.
        for sample_num in range(batch_size):
            html_rel_path = os.path.join('sample_{}'.format(sample_num), '{}.html'.format(prefix))
            html_names.append(html_rel_path)
            html_path = os.path.join(save_dir, html_rel_path)
            os.makedirs(os.path.dirname(html_path), exist_ok=True)
            html_file = open(html_path, 'wt')
            html_list.append(html_file)
            html_file.write('<html><head></head><body><table>\n')

        for layer in layers:
            layer_name = 'layer_{}'.format(layer)
            # Image path relative to the HTML page (which sits next to the prefix_stem dir).
            relative_image_path = os.path.join(prefix_stem, layer_name) + '.jpg'
            for sample_num in range(batch_size):
                save_path = os.path.join(save_dir, 'sample_{}'.format(sample_num),
                                         prefix, layer_name) + '.jpg'
                os.makedirs(os.path.dirname(save_path), exist_ok=True)

                html_list[sample_num].write(f'<tr><td><h1>{layer_name}</h1></td></tr>\n')
                html_list[sample_num].write(f'<tr><td><img src="{relative_image_path}"></td></tr>\n')

                _save_token_heatmap_grid(attention_maps[layer][sample_num], tokens_vis, save_path)

        for html_file in html_list:
            html_file.write('</table></body></html>')
    finally:
        # Always release the file handles, even if plotting fails midway.
        for html_file in html_list:
            html_file.close()

    return html_names


def create_recursive_html_link(html_path, save_dir):
    r"""Function for creating recursive html links.
    If the path is dir1/dir2/dir3/*.html,
    we create chained directories
        -dir1
            dir1.html (has links to all children)
            -dir2
                dir2.html   (has links to all children)
                -dir3
                    dir3.html

    Each index page is a single ``<html><table>`` document with one link row per
    child. A link that is already present is not added again. Index files
    previously written as several concatenated documents are rewritten into one.

    Args:
        html_path (str): '/'-separated path to an html file, relative to save_dir.
        save_dir (str): Directory in which the top-level index page is written.
    """
    html_path_split = os.path.splitext(html_path)[0].split('/')
    if len(html_path_split) == 1:
        return

    root_dir, child_dir = html_path_split[0], html_path_split[1]
    cur_html_path = os.path.join(save_dir, '{}.html'.format(root_dir))
    child_path = os.path.join(root_dir, f'{child_dir}.html')
    line_to_write = f'<tr><td><a href="{child_path}">{child_dir}</a></td></tr>\n'

    # Collect the link rows already in the index. Only '<tr>' lines are kept,
    # dropping the document wrappers.
    rows = []
    if os.path.exists(cur_html_path):
        with open(cur_html_path, 'r') as fp:
            rows = [line for line in fp if line.startswith('<tr>')]

    if line_to_write not in rows:
        rows.append(line_to_write)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(cur_html_path, 'w') as fp:
            fp.write('<html><head></head><body><table>\n')
            fp.writelines(rows)
            fp.write('</table></body></html>')

    child_html_path = '/'.join(html_path.split('/')[1:])
    create_recursive_html_link(child_html_path, os.path.join(save_dir, root_dir))


def visualize_attention_maps(attention_maps_all, save_dir, width, height, tokens_vis, step_interval=5):
    r"""Save per-layer attention heatmaps for every ``step_interval``-th step and link them in HTML.

    Args:
        attention_maps_all (dict): Layer name -> per-step sequence of tensors.
            Each entry is unpacked as ``(bs, hw, nclip)`` and reshaped to
            ``(bs, h, w, nclip)`` before plotting.
        save_dir (str): Output directory for this run. Its last '/'-component
            becomes the top-level link in an index written into its parent.
        width (int): Grid width corresponding to ``hw == width * height``.
            Smaller maps are assumed to be downsampled by a square factor.
        height (int): Grid height, see ``width``.
        tokens_vis (list[str]): Token labels forwarded to ``save_attention_heatmaps``.
        step_interval (int): Stride over steps (default 5).
    """
    if not attention_maps_all:
        return

    first_layer = next(iter(attention_maps_all))
    nsteps = len(attention_maps_all[first_layer])
    hw_ori = width * height

    text_input = save_dir.split('/')[-1]
    save_dir_root = '/'.join(save_dir.split('/')[0:-1])

    all_html_paths = []
    for step_num in range(0, nsteps, step_interval):
        attention_maps = {}
        for layer, per_step_maps in attention_maps_all.items():
            attention_ind = per_step_maps[step_num].cpu()
            # Recover the spatial grid from the ratio of this layer's hw to width*height.
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            attention_maps[layer] = attention_ind.reshape(bs, height_cur, width_cur, nclip)

        html_names = save_attention_heatmaps(
            attention_maps, tokens_vis, save_dir=save_dir,
            prefix='step_{}/attention_maps_cond'.format(step_num)
        )
        all_html_paths.extend(os.path.join(text_input, name) for name in html_names)

    for html_pth in all_html_paths:
        create_recursive_html_link(html_pth, save_dir_root)


def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
    r"""Plot each group of maps as a row of heatmaps with a shared colorbar.

    Args:
        atten_map_list (list): Groups of maps. Group ``i`` is saved to
            ``<save_dir>/average_seed<seed>_attn<i>.jpg``. Map ``m`` of a group
            is drawn from ``m[0]``.
        obj_tokens (list): Per-object token-id tensors, used only for titles.
        save_dir (str): Output directory.
        seed (int): Embedded in the file names.
        tokens_vis (list[str], optional): Token strings. ``obj_tokens`` ids are
            looked up as ``tokens_vis[id - 1]``, and a trailing ``'</w>'`` is
            stripped. The last map of a group is titled 'other tokens'.

    Returns:
        numpy.ndarray | None: RGB uint8 rendering ``(H, W, 3)`` of the last
        figure, or ``None`` when ``atten_map_list`` is empty.
    """
    img = None
    for i, attn_map in enumerate(atten_map_list):
        n_obj = len(attn_map)
        fig, axs = plt.subplots(
            ncols=n_obj + 1, gridspec_kw=dict(width_ratios=[1] * n_obj + [0.1]))
        try:
            fig.set_figheight(3)
            fig.set_figwidth(3 * n_obj + 0.1)

            cmap = plt.get_cmap('YlOrRd')
            # One color range shared by all maps of the group.
            vmax = max(float(m.max()) for m in attn_map)
            vmin = min(float(m.min()) for m in attn_map)

            for tid in range(n_obj):
                ax = axs[tid]
                sns.heatmap(
                    attn_map[tid][0], annot=False, cbar=False, ax=ax,
                    cmap=cmap, vmin=vmin, vmax=vmax
                )
                ax.set_axis_off()

                if tokens_vis is not None:
                    if tid == n_obj - 1:
                        title = 'other tokens'
                    else:
                        words = []
                        for token_id in obj_tokens[tid]:
                            token = tokens_vis[token_id.item() - 1]
                            if token.endswith('</w>'):
                                token = token[:-len('</w>')]
                            words.append(token)
                        title = ''.join(' ' + word for word in words)
                    ax.set_title(title)

            sm = plt.cm.ScalarMappable(
                cmap=cmap, norm=mpl.colors.Normalize(vmin=vmin, vmax=vmax))
            fig.colorbar(sm, cax=axs[-1])
            fig.tight_layout()

            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
            # later removed) and already has the rendered (H, W) shape.
            img = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
            fig.savefig(os.path.join(
                save_dir, 'average_seed%d_attn%d.jpg' % (seed, i)), dpi=100)
        finally:
            plt.close(fig)
    return img


def get_average_attention_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
                               preprocess=False):
    r"""Average conditional cross-attention maps per object and turn them into region masks.

    Only steps >= 10 and layers listed in ``CrossAttentionLayers`` contribute.
    For each object, the maximum over its tokens is resized to
    ``(height, width)`` and averaged over steps and layers.

    Args:
        attention_maps (dict): Per-layer, per-step maps accepted by
            ``split_attention_maps_over_steps``. Each conditional entry is
            unpacked as ``(bs, hw, nclip)``.
        save_dir (str): Directory where ``plot_attention_maps`` writes figures.
        width (int): Output width. ``width * height`` is also the reference
            size used to infer each layer's spatial grid.
        height (int): Output height.
        obj_tokens (list): One token-index tensor per object. An entry whose
            first element is -1 denotes a complement region built from the
            objects listed before it.
        seed (int): Used in output file names.
        tokens_vis (list[str], optional): Token strings for plot titles.
        preprocess (bool): If True, threshold the first two normalized maps
            at 0.8 and append a complement mask.

    Returns:
        list[torch.Tensor]: Region masks, each repeated to 4 channels along
        dim 1 and moved to CUDA.
    """
    def _erode(mask, size):
        # Grayscale erosion with a size x size square footprint, i.e. a sliding
        # minimum. max_pool2d pads with -inf, so out-of-image pixels never win.
        return -torch.nn.functional.max_pool2d(
            -mask[None, None], size, stride=1, padding=size // 2)[0, 0]

    attention_maps_cond, _ = split_attention_maps_over_steps(attention_maps)
    hw_ori = width * height

    # per_object_maps[obj_id] collects one (1, bs, height, width) map per (step, layer).
    per_object_maps = [[] for _ in obj_tokens]

    for step_num in range(len(attention_maps_cond)):
        if step_num < 10:
            continue
        for layer, attention_ind in attention_maps_cond[step_num].items():
            if layer not in CrossAttentionLayers:
                continue

            attention_ind = attention_ind.cpu()
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            attention_ind = attention_ind.reshape(bs, height_cur, width_cur, nclip)

            for obj_id, obj_token in enumerate(obj_tokens):
                if obj_token[0] == -1:
                    # Complement of the maps just appended for the preceding objects.
                    attention_map_prev = torch.stack(
                        [per_object_maps[i][-1] for i in range(obj_id)]).sum(0)
                    per_object_maps[obj_id].append(attention_map_prev.max() - attention_map_prev)
                else:
                    # Max over the object's tokens -> (1, bs, h, w), then upsample.
                    obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
                        0].permute([3, 0, 1, 2])
                    obj_attention_map = torchvision.transforms.functional.resize(
                        obj_attention_map, (height, width),
                        interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
                    per_object_maps[obj_id].append(obj_attention_map)

    attention_maps_averaged = [torch.cat(maps).mean(0) for maps in per_object_maps]

    if obj_tokens[-1][0] != -1:
        # No complement entry: near-hard assignment via a low-temperature softmax.
        stacked = (torch.cat(attention_maps_averaged) / 0.001).softmax(0)
        attention_maps_averaged_normalized = [stacked[i:i + 1] for i in range(stacked.shape[0])]
    else:
        # Normalize so the object maps sum to one at every pixel.
        attention_maps_averaged_sum = torch.cat(attention_maps_averaged).sum(0)
        attention_maps_averaged_normalized = [
            averaged / attention_maps_averaged_sum for averaged in attention_maps_averaged]

    if preprocess:
        # A 1x1 footprint makes the erosion an identity (sizes 5 and 3 were
        # previously assigned and then overridden by 1).
        erosion_size = 1
        # NOTE(review): only the first two maps are thresholded, which assumes
        # two foreground objects — confirm for other object counts.
        attention_maps_averaged_eroded = [
            (_erode(normalized[0], erosion_size).unsqueeze(0) > 0.8).float()
            for normalized in attention_maps_averaged_normalized[:2]]
        attention_maps_averaged_eroded.append(
            1 - torch.cat(attention_maps_averaged_eroded).sum(0, True))
        plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized,
                             attention_maps_averaged_eroded], obj_tokens, save_dir, seed, tokens_vis)
        return [attn_mask.unsqueeze(1).repeat([1, 4, 1, 1]).cuda()
                for attn_mask in attention_maps_averaged_eroded]

    plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
                        obj_tokens, save_dir, seed, tokens_vis)
    return [attn_mask.unsqueeze(1).repeat([1, 4, 1, 1]).cuda()
            for attn_mask in attention_maps_averaged_normalized]


def get_average_attention_maps_threshold(attention_maps, save_dir, width, height, obj_tokens, seed=0, threshold=0.02):
    r"""Average conditional cross-attention per object, threshold it, and build region masks.

    Layers whose reshaped map height exceeds ``width // 2`` are skipped. For
    each object (first token != -1), the mean over its tokens is resized to
    ``(height, width)``, averaged over steps and layers, and thresholded at
    ``threshold``. Overlapping masks are shared evenly. An object whose first
    token is -1 receives ``1 -`` the sum of the masks computed before it.

    Args:
        attention_maps (dict): Per-layer, per-step maps accepted by
            ``split_attention_maps_over_steps``. Each conditional entry is
            unpacked as ``(bs, hw, nclip)``.
        save_dir (str): Directory where ``plot_attention_maps`` writes figures.
        width (int): Output width. ``width * height`` is also the reference
            size used to infer each layer's spatial grid.
        height (int): Output height.
        obj_tokens (list): One token-index tensor per object. A -1 entry must
            come after at least one real object.
        seed (int): Used in output file names.
        threshold (float): Cut-off applied to the averaged attention.

    Returns:
        list[torch.Tensor]: One mask per entry of ``obj_tokens``, repeated to 4
        channels along dim 1 and moved to CUDA.
    """
    _EPS = 1e-8
    attention_maps_cond, _ = split_attention_maps_over_steps(attention_maps)
    hw_ori = width * height

    per_object_maps = [[] for _ in obj_tokens]

    # For each object, collect its attention map for every step and layer.
    for step_num in range(len(attention_maps_cond)):
        for attention_ind in attention_maps_cond[step_num].values():
            attention_ind = attention_ind.cpu()
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            attention_ind = attention_ind.reshape(bs, height_cur, width_cur, nclip)
            if height_cur > width // 2:
                continue
            for obj_id, obj_token in enumerate(obj_tokens):
                if obj_token[0] == -1:
                    continue
                obj_attention_map = attention_ind[:, :, :, obj_token].mean(
                    -1, True).permute([3, 0, 1, 2])
                obj_attention_map = torchvision.transforms.functional.resize(
                    obj_attention_map, (height, width),
                    interpolation=torchvision.transforms.InterpolationMode.BICUBIC, antialias=True)
                per_object_maps[obj_id].append(obj_attention_map)

    # Average over all steps and layers, then threshold.
    attention_maps_averaged = []
    attention_maps_thres = []
    for obj_token, maps in zip(obj_tokens, per_object_maps):
        if obj_token[0] != -1:
            average_map = torch.cat(maps).mean(0)
            attention_maps_averaged.append(average_map)
            attention_maps_thres.append((average_map > threshold).float())

    # attention_maps_thres skips -1 entries, so consume it in order rather than
    # indexing it by obj_id.
    attention_maps_averaged_sum = torch.cat(attention_maps_thres).sum(0) + _EPS
    thres_iter = iter(attention_maps_thres)
    attention_maps_averaged_normalized = []
    for obj_token in obj_tokens:
        if obj_token[0] != -1:
            attention_maps_averaged_normalized.append(next(thres_iter) / attention_maps_averaged_sum)
        else:
            # Remaining region not covered by the preceding objects.
            attention_map_prev = torch.stack(attention_maps_averaged_normalized).sum(0)
            attention_maps_averaged_normalized.append(1. - attention_map_prev)

    plot_attention_maps(
        [attention_maps_averaged, attention_maps_averaged_normalized], obj_tokens, save_dir, seed)

    return [attn_mask.unsqueeze(1).repeat([1, 4, 1, 1]).cuda()
            for attn_mask in attention_maps_averaged_normalized]


def get_token_maps(selfattn_maps, crossattn_maps, n_maps, save_dir, width, height, obj_tokens, kmeans_seed=0, tokens_vis=None,
                   preprocess=False, segment_threshold=0.3, num_segments=5, return_vis=False, save_attn=False):
    r"""Segment by spectral clustering of self-attention, then label segments with cross-attention.

    Args:
        selfattn_maps (dict): Self-attention maps. Only entries whose spatial
            side ``int(sqrt(shape[1]))`` equals 32 are used; each is reshaped
            as ``(1, 32, 32, 32 ** 2)`` and averaged into a 1024 x 1024 affinity.
        crossattn_maps (dict): Cross-attention maps. Each is reshaped as
            ``(1, r, r, n_tokens)`` with ``r = int(sqrt(shape[1]))``, resized
            to 32 x 32 and averaged.
        n_maps: Unused; kept for interface compatibility.
        save_dir (str): Directory for the segmentation and token-map figures.
        width (int): Width of the returned maps.
        height (int): Height of the returned maps.
        obj_tokens (list[torch.Tensor]): Token indices per object (CPU tensors);
            they are clipped to [0, 76] before lookup.
        kmeans_seed (int): Passed to ``seed_everything`` before clustering and
            used in output file names.
        tokens_vis (list[str], optional): Token strings for plot titles.
        preprocess (bool): If True, erode the foreground maps with a 5x5 square
            footprint and rebuild the background as their complement.
        segment_threshold (float): A segment joins an object's map when the
            mean min-max-normalized cross-attention of any of the object's
            tokens over the segment exceeds this value.
        num_segments (int): Number of spectral-clustering segments.
        return_vis (bool): Also return the segmentation rendering and the value
            returned by ``plot_attention_maps``.
        save_attn (bool): Also save the self-attention affinity to
            ``results/maps/selfattn_maps.pth``.

    Returns:
        list[torch.Tensor] | tuple: One map per object plus a trailing
        background map, each ``(1, 4, height, width)`` float32 on CUDA; with
        ``return_vis`` also ``segments_vis`` and ``token_maps_vis``.
    """
    resolution = 32

    def _erode(maps, size):
        # Grayscale erosion of (N, H, W) maps with a size x size square footprint
        # (a sliding minimum). max_pool2d pads with -inf, so out-of-image pixels
        # never win the minimum.
        return -torch.nn.functional.max_pool2d(
            -maps.unsqueeze(1), size, stride=1, padding=size // 2).squeeze(1)

    # Aggregate self-attention at `resolution` into a (res**2, res**2) affinity.
    self_attn_maps = []
    for attn_map in selfattn_maps.values():
        resolution_map = int(np.sqrt(attn_map.shape[1]))
        if resolution_map != resolution:
            continue
        attn_map = attn_map.reshape(
            1, resolution_map, resolution_map, resolution_map**2).permute([3, 0, 1, 2]).float()
        attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
                                                   mode='bicubic', antialias=True)
        self_attn_maps.append(attn_map.permute([1, 2, 3, 0]).reshape(
            1, resolution**2, resolution_map**2))
    attn_maps_1024 = torch.cat(self_attn_maps).mean(0).cpu().numpy()

    if save_attn:
        print('saving self-attention maps...', attn_maps_1024.shape)
        os.makedirs('results/maps', exist_ok=True)
        torch.save(torch.from_numpy(attn_maps_1024),
                   'results/maps/selfattn_maps.pth')

    seed_everything(kmeans_seed)
    sc = SpectralClustering(num_segments, affinity='precomputed', n_init=100,
                            assign_labels='kmeans')
    clusters = sc.fit_predict(attn_maps_1024).reshape(resolution, resolution)

    fig = plt.figure()
    try:
        plt.imshow(clusters)
        plt.axis('off')
        fig.savefig(os.path.join(save_dir, 'segmentation_k%d_seed%d.jpg' % (num_segments, kmeans_seed)),
                    bbox_inches='tight', pad_inches=0)
        if return_vis:
            fig.canvas.draw()
            # buffer_rgba() replaces the deprecated/removed tostring_rgb().
            segments_vis = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])
    finally:
        plt.close(fig)

    # Average cross-attention at `resolution` -> (res, res, n_tokens).
    cross_attn_maps = []
    for attn_map in crossattn_maps.values():
        resolution_map = int(np.sqrt(attn_map.shape[1]))
        attn_map = attn_map.reshape(
            1, resolution_map, resolution_map, -1).permute([0, 3, 1, 2]).float()
        attn_map = torch.nn.functional.interpolate(attn_map, (resolution, resolution),
                                                   mode='bicubic', antialias=True)
        cross_attn_maps.append(attn_map.permute([0, 2, 3, 1]))
    cross_attn_maps_1024 = torch.cat(cross_attn_maps).mean(0).cpu().numpy()

    # Min-max normalize every token channel of each object's span. Subtract the
    # true minimum: bicubic overshoot can make it negative. The range is
    # floored to avoid 0/0 on constant maps.
    normalized_span_maps = []
    for token_ids in obj_tokens:
        token_ids = torch.clip(token_ids, 0, 76)
        span_token_maps = cross_attn_maps_1024[:, :, token_ids.numpy()]
        span_min = span_token_maps.min(axis=(0, 1), keepdims=True)
        span_max = span_token_maps.max(axis=(0, 1), keepdims=True)
        normalized_span_maps.append(
            (span_token_maps - span_min) / np.maximum(span_max - span_min, 1e-8))

    # Assign every segment to each object it scores high enough for; otherwise background.
    foreground_token_maps = [np.zeros(clusters.shape) for _ in normalized_span_maps]
    background_map = np.zeros(clusters.shape)
    for c in range(num_segments):
        cluster_mask = (clusters == c).astype(np.float64)
        cluster_size = cluster_mask.sum()
        is_foreground = False
        for normalized_span_map, foreground_nouns_map in zip(normalized_span_maps, foreground_token_maps):
            # Mean normalized attention of each of the object's tokens over this segment.
            scores = (normalized_span_map * cluster_mask[:, :, None]).sum(axis=(0, 1)) / cluster_size
            if scores.max() > segment_threshold:
                foreground_nouns_map += cluster_mask
                is_foreground = True
        if not is_foreground:
            background_map += cluster_mask
    foreground_token_maps.append(background_map)

    # Resize to the output resolution and normalize so the maps sum to ~1 per pixel.
    resized_token_maps = torch.cat([
        torch.nn.functional.interpolate(torch.from_numpy(token_map)[None, None], (height, width),
                                        mode='bicubic', antialias=True)[0]
        for token_map in foreground_token_maps]).clamp(0, 1)
    resized_token_maps = resized_token_maps / (resized_token_maps.sum(0, True) + 1e-8)
    resized_token_maps = [token_map.unsqueeze(0) for token_map in resized_token_maps]
    foreground_token_maps = [token_map[None, :, :] for token_map in foreground_token_maps]

    if preprocess:
        eroded_token_maps = _erode(torch.cat(resized_token_maps[:-1]), 5).clamp(0, 1)
        eroded_background_maps = (1 - eroded_token_maps.sum(0, True)).clamp(0, 1)
        eroded_token_maps = torch.cat([eroded_token_maps, eroded_background_maps])
        eroded_token_maps = eroded_token_maps / (eroded_token_maps.sum(0, True) + 1e-8)
        resized_token_maps = [token_map.unsqueeze(0) for token_map in eroded_token_maps]

    token_maps_vis = plot_attention_maps([foreground_token_maps, resized_token_maps], obj_tokens,
                                         save_dir, kmeans_seed, tokens_vis)
    # float32 explicitly: the same dtype the old code got from the leaked
    # cross-attention loop variable (always float32 after .float()).
    resized_token_maps = [token_map.unsqueeze(1).repeat(
        [1, 4, 1, 1]).to(torch.float32).cuda() for token_map in resized_token_maps]
    if return_vis:
        return resized_token_maps, segments_vis, token_maps_vis
    return resized_token_maps