"""Functions to help with searching codes using regex."""

import pickle
import re

import numpy as np
import torch
from tqdm import tqdm


def load_dataset_cache(cache_base_path):
    """Load the dataset cache arrays stored under `cache_base_path`.

    `cache_base_path` is used as a plain string prefix: the file names are
    appended directly, without inserting a path separator.

    Returns:
        Tuple `(tokens_str, tokens_text, token_byte_pos)` read with `np.load`
        from `tokens_str.npy`, `tokens_text.npy` and `token_byte_pos.npy`.
    """
    file_names = ("tokens_str.npy", "tokens_text.npy", "token_byte_pos.npy")
    return tuple(np.load(cache_base_path + file_name) for file_name in file_names)


def load_code_search_cache(cache_base_path):
    """Load the code-search cache stored under `cache_base_path`.

    `cache_base_path` is used as a plain string prefix for the file names.

    Returns:
        Tuple `(cb_acts, act_count_ft_tkns, metrics)`:
        - `cb_acts`: object unpickled from `cb_acts.pkl`.
        - `act_count_ft_tkns`: object unpickled from `act_count_ft_tkns.pkl`.
        - `metrics`: the single object stored in `metrics.npy`.
    """
    # NOTE: both pickle.load and np.load(allow_pickle=True) can execute
    # arbitrary code; only load cache files produced by a trusted source.
    metrics = np.load(cache_base_path + "metrics.npy", allow_pickle=True).item()
    unpickled = []
    for pickle_name in ("cb_acts.pkl", "act_count_ft_tkns.pkl"):
        with open(cache_base_path + pickle_name, "rb") as handle:
            unpickled.append(pickle.load(handle))
    cb_acts, act_count_ft_tkns = unpickled
    return cb_acts, act_count_ft_tkns, metrics


def search_re(re_pattern, tokens_text, at_odd_even=-1):
    """Get list of (example_id, char_pos) where `re_pattern` matches in `tokens_text`.

    The reported position is the start offset of capture group 1 when the
    pattern defines at least one capturing group, otherwise the start offset
    of the whole match. Matches whose reported span is empty are skipped.

    Args:
        re_pattern: regex pattern (string or compiled pattern) to search for.
        tokens_text: list of example texts.
        at_odd_even: limit matches by the parity of their 0-based start offset.
            -1 (default): do not limit matches.
            0: keep only matches starting at an even offset.
            1: keep only matches starting at an odd offset.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        list of (example_id, char_pos) tuples, ordered by example and then by
        position within the example.
    """
    assert at_odd_even in [-1, 0, 1], f"Invalid at_odd_even: {at_odd_even}"
    pattern = re.compile(re_pattern)
    # Wrapping a group-less pattern as "(pattern)" makes group 1 identical to
    # the whole match (group 0), so group 0 is used directly. Deciding from the
    # compiled group count, rather than searching for a literal "(", also
    # covers escaped parentheses, character classes, non-capturing groups and
    # inline flags. Those cases previously left group 1 undefined and raised
    # IndexError on the first match.
    group = 1 if pattern.groups else 0
    res = []
    for i, text in enumerate(tokens_text):
        for match in pattern.finditer(text):
            start, end = match.span(group)
            if start == end:
                continue
            if at_odd_even != -1 and start % 2 != at_odd_even:
                continue
            res.append((i, start))
    return res


def byte_id_to_token_pos_id(example_byte_id, token_byte_pos):
    """Map a character (byte) offset in an example to a token position.

    Used to convert the start of a regex match to its token position.

    Args:
        example_byte_id: pair `(example_id, byte_id)`, where `byte_id` is a
            character offset in the text of example `example_id`.
        token_byte_pos: array indexed as `token_byte_pos[example_id]`, giving
            a sorted 1-D sequence of byte positions for that example's tokens.

    Returns:
        `(example_id, token_pos_id)`, where `token_pos_id` is the number of
        entries of `token_byte_pos[example_id]` that are <= `byte_id`.
        This is the token containing `byte_id` if the entries are token end
        offsets (assumption; confirm against how the cache is built).
    """
    ex_id, char_pos = example_byte_id
    boundaries = token_byte_pos[ex_id]
    # side="right": an offset equal to a boundary is counted as past it.
    token_pos = np.searchsorted(boundaries, char_pos, side="right")
    return (ex_id, token_pos)


def get_code_precision_and_recall(
    token_pos_ids, codebook_acts, cb_act_counts=None, min_recall=0.01
):
    """Search for the codes that activate on the given `token_pos_ids`.

    Args:
        token_pos_ids: sequence of (example_id, token_pos_id) pairs.
        codebook_acts: array of active code ids, indexed as
            `codebook_acts[example_id][token_pos_id]`, with shape
            (num_examples, seq_len, k_codebook).
        cb_act_counts: mapping or array indexed by code id, where
            `cb_act_counts[code]` is the number of times `code` is activated
            in the dataset. If None, precision is not computed and codes are
            sorted by recall instead.
        min_recall: codes whose recall is not strictly greater than this
            value are dropped. Defaults to 0.01.

    Returns:
        codes: numpy array of code ids, sorted by descending precision (or by
            descending recall when `cb_act_counts` is None).
        prec: numpy array where `prec[i]` is the precision of `codes[i]` for
            the given `token_pos_ids`. All zeros when `cb_act_counts` is None.
        recall: numpy array where `recall[i]` is the recall of `codes[i]`
            for the given `token_pos_ids`.
        code_acts: numpy array where `code_acts[i]` is the number of times
            `codes[i]` is activated in the dataset. All zeros when
            `cb_act_counts` is None.
    """
    codes = np.array(
        [
            codebook_acts[example_id][token_pos_id]
            for example_id, token_pos_id in token_pos_ids
        ]
    )
    # np.unique flattens, so each of the k active codes of every matched
    # token contributes to the count of its code.
    codes, counts = np.unique(codes, return_counts=True)
    recall = counts / len(token_pos_ids)
    keep = recall > min_recall
    codes, counts, recall = codes[keep], counts[keep], recall[keep]
    if cb_act_counts is not None:
        code_acts = np.array([cb_act_counts[code] for code in codes])
        prec = counts / code_acts
        sort_idx = np.argsort(prec)[::-1]
    else:
        code_acts = np.zeros_like(codes)
        prec = np.zeros_like(codes)
        sort_idx = np.argsort(recall)[::-1]
    codes, prec, recall = codes[sort_idx], prec[sort_idx], recall[sort_idx]
    code_acts = code_acts[sort_idx]
    return codes, prec, recall, code_acts


def get_neuron_precision_and_recall(
    token_pos_ids, recall, neuron_acts_by_ex, neuron_sorted_acts
):
    """Get the neurons with the highest precision and recall for the given `token_pos_ids`.

    Args:
        token_pos_ids: list of token (example_id, token_pos_id) tuples from a dataset over which
            the neurons with the highest precision and recall are to be found.
        recall: recall threshold for the neurons (this determines their activation threshold).
        neuron_acts_by_ex: numpy array of activations of all the attention and mlp output neurons
            on a dataset with shape (num_examples, seq_len, num_layers, 2, dim_size).
            The third dimension is 2 because we consider neurons from both: attention and mlp.
        neuron_sorted_acts: numpy array of sorted activations of all the attention and mlp output neurons
            on a dataset with shape (num_layers, 2, dim_size, num_examples * seq_len).
            This should be obtained using the `neuron_acts_by_ex` array by rearranging the first two
            dimensions to the last dimensions and then sorting the last dimension.

    Returns:
        best_prec: highest precision amongst all the neurons for the given `token_pos_ids`.
        best_neuron_acts: number of activations of the best neuron for the given `token_pos_ids`
            based on the threshold determined by the `recall` argument.
        best_neuron_idx: tuple of (layer, is_mlp, neuron_id) where `layer` is the layer number,
            `is_mlp` is 0 if the neuron is from attention and 1 if the neuron is from mlp,
            and `neuron_id` is the neuron's index in the layer.
    """
    if isinstance(neuron_acts_by_ex, torch.Tensor):
        neuron_acts_on_pattern = torch.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            dim=-1,
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern = torch.sort(neuron_acts_on_pattern, dim=-1).values
    else:
        neuron_acts_on_pattern = np.stack(
            [
                neuron_acts_by_ex[example_id, token_pos_id]
                for example_id, token_pos_id in token_pos_ids
            ],
            axis=-1,
        )  # (layers, 2, dim_size, matches)
        neuron_acts_on_pattern.sort(axis=-1)
        neuron_acts_on_pattern = torch.from_numpy(neuron_acts_on_pattern)
    act_thresh = neuron_acts_on_pattern[
        :, :, :, -int(recall * neuron_acts_on_pattern.shape[-1])
    ]
    assert neuron_sorted_acts.shape[:-1] == act_thresh.shape
    prec_den = torch.searchsorted(neuron_sorted_acts, act_thresh.unsqueeze(-1))
    prec_den = prec_den.squeeze(-1)
    prec_den = neuron_sorted_acts.shape[-1] - prec_den
    prec = int(recall * neuron_acts_on_pattern.shape[-1]) / prec_den
    assert (
        prec.shape == neuron_acts_on_pattern.shape[:-1]
    ), f"{prec.shape} != {neuron_acts_on_pattern.shape[:-1]}"

    best_neuron_idx = np.unravel_index(prec.argmax(), prec.shape)
    best_prec = prec[best_neuron_idx]
    best_neuron_act_thresh = act_thresh[best_neuron_idx].item()
    best_neuron_acts = neuron_acts_by_ex[
        :, :, best_neuron_idx[0], best_neuron_idx[1], best_neuron_idx[2]
    ]
    best_neuron_acts = best_neuron_acts >= best_neuron_act_thresh
    best_neuron_acts = np.stack(np.where(best_neuron_acts), axis=-1)

    return best_prec, best_neuron_acts, best_neuron_idx


def convert_to_adv_name(name, cb_at, gcb=""):
    """Convert a base codebook name to its advanced name.

    Examples:
        grouped (`gcb` truthy):   "layer0_head0" -> "layer0_<cb_at>_gcb0"
        non-grouped (`gcb` falsy): "layer0"      -> "layer0_<cb_at>"

    Args:
        name: base name, "layer<L>" or "layer<L>_head<H>".
        cb_at: codebook location inserted after the layer, e.g. "attn_preproj".
        gcb: truthy for grouped codebooks, "" for non-grouped codebooks.

    Returns:
        The advanced codebook name.
    """
    # The layer is needed in both branches; the original non-grouped branch
    # referenced `layer` without defining it and always raised.
    layer = name.split("_")[0]
    if gcb:
        _, head = name.split("_")
        return layer + f"_{cb_at}_gcb" + head[4:]
    return layer + "_" + cb_at


def convert_to_base_name(name, gcb=""):
    """Convert an advanced codebook name back to its base name.

    Names containing "gcb" map to "<layer>_head<H>", where `<layer>` is the
    first "_"-separated part and `H` is the last part without its first three
    characters ("layer0_attn_preproj_gcb0" -> "layer0_head0"). Other names map
    to their first "_"-separated part ("layer0_attn_preproj" -> "layer0").

    The `gcb` argument is accepted for signature compatibility but not used.
    """
    parts = name.split("_")
    layer_part = parts[0]
    if "gcb" not in name:
        return layer_part
    return layer_part + "_head" + parts[-1][3:]


def get_layer_head_from_base_name(name):
    """Parse "layer<L>" or "layer<L>_head<H>" into integers `(L, H)`.

    `H` is None when the name has no "_" separated part after the layer.
    When there are several parts, the head is read from the last one.
    """
    layer_part, *other_parts = name.split("_")
    layer = int(layer_part[5:])  # strip the "layer" prefix
    head = int(other_parts[-1][4:]) if other_parts else None  # strip "head"
    return layer, head


def get_layer_head_from_adv_name(name):
    """Parse an advanced codebook name such as "layer0_attn_preproj_gcb0" into `(layer, head)`.

    The name is first reduced with `convert_to_base_name` and then parsed with
    `get_layer_head_from_base_name`. `head` is None for names without "gcb".
    """
    return get_layer_head_from_base_name(convert_to_base_name(name))


def get_codes_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    cb_acts,
    act_count_ft_tkns,
    gcb="",
    topk=5,
    prec_threshold=0.5,
    at_odd_even=-1,
):
    """Fetch codes that activate on a given regex pattern.

    For every codebook, returns at most `topk` of its highest-precision codes
    whose precision is strictly above `prec_threshold`.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        cb_acts: dict mapping codebook name to its activations, indexable as
            `cb[example_id][token_pos]`.
        act_count_ft_tkns: dict keyed by base codebook name (as produced by
            `convert_to_base_name`). Each value is indexed by code id and gives
            the number of token activations of that code on the dataset.
        gcb: "_gcb" for grouped codebooks and "" for non-grouped codebooks.
        topk: maximum number of codes to return per codebook.
        prec_threshold: codes need precision strictly greater than this.
        at_odd_even: limit regex matches by the parity of their 0-based
            character offset in the text.
            -1 (default): do not limit matches.
            0: keep only matches starting at an even offset.
            1: keep only matches starting at an odd offset.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        codebook_wise_codes: dict of base codebook name to a list of
            (code, prec, recall, code_acts) tuples.
        re_token_matches: number of unique tokens that match the regex pattern.
    """
    byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    token_pos_ids = [
        byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
    ]
    # Several character matches can fall into the same token; count it once.
    token_pos_ids = np.unique(token_pos_ids, axis=0)
    re_token_matches = len(token_pos_ids)
    codebook_wise_codes = {}
    for cb_name, cb in tqdm(cb_acts.items()):
        base_cb_name = convert_to_base_name(cb_name, gcb=gcb)
        codes, prec, recall, code_acts = get_code_precision_and_recall(
            token_pos_ids,
            cb,
            cb_act_counts=act_count_ft_tkns[base_cb_name],
        )
        # Codes come back sorted by descending precision, so the candidates
        # are the first `topk`; keep those above the precision threshold.
        idx = np.arange(min(topk, len(codes)))
        idx = idx[prec[:topk] > prec_threshold]
        codes, prec, recall = codes[idx], prec[idx], recall[idx]
        code_acts = code_acts[idx]
        codebook_wise_codes[base_cb_name] = list(zip(codes, prec, recall, code_acts))
    return codebook_wise_codes, re_token_matches


def get_neurons_from_pattern(
    re_pattern,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    recall_threshold,
    at_odd_even=-1,
):
    """Fetch the highest precision neuron that activates on a given regex pattern.

    The activation threshold for the neurons is determined by `recall_threshold`.

    Args:
        re_pattern: regex pattern to search for.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: activations of all the attention and mlp output
            neurons on a dataset, with shape
            (num_examples, seq_len, num_layers, 2, dim_size). The fourth
            dimension has size 2 because neurons from both attention and mlp
            are considered.
        neuron_sorted_acts: sorted activations of the same neurons with shape
            (num_layers, 2, dim_size, num_examples * seq_len). This should be
            obtained from `neuron_acts_by_ex` by moving the first two
            dimensions to the end, flattening them and sorting the last
            dimension.
        recall_threshold: recall fraction for the neurons (this determines
            their activation threshold).
        at_odd_even: limit regex matches by the parity of their 0-based
            character offset in the text.
            -1 (default): do not limit matches.
            0: keep only matches starting at an even offset.
            1: keep only matches starting at an odd offset.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        best_prec: highest precision amongst all the neurons on the matched tokens.
        best_neuron_acts: array of (example_id, token_pos) pairs where the best
            neuron's activation reaches its recall-derived threshold.
        best_neuron_idx: tuple of (layer, is_mlp, neuron_id), where `layer` is
            the layer number, `is_mlp` is 0 if the neuron is from attention
            and 1 if the neuron is from mlp, and `neuron_id` is the neuron's
            index in the layer.
        re_token_matches: number of unique tokens that match the regex pattern.
    """
    byte_ids = search_re(re_pattern, tokens_text, at_odd_even=at_odd_even)
    token_pos_ids = [
        byte_id_to_token_pos_id(ex_byte_id, token_byte_pos) for ex_byte_id in byte_ids
    ]
    # Several character matches can fall into the same token; count it once.
    token_pos_ids = np.unique(token_pos_ids, axis=0)
    re_token_matches = len(token_pos_ids)
    best_prec, best_neuron_acts, best_neuron_idx = get_neuron_precision_and_recall(
        token_pos_ids,
        recall_threshold,
        neuron_acts_by_ex,
        neuron_sorted_acts,
    )
    return best_prec, best_neuron_acts, best_neuron_idx, re_token_matches


def compare_codes_with_neurons(
    best_codes_info,
    tokens_text,
    token_byte_pos,
    neuron_acts_by_ex,
    neuron_sorted_acts,
    at_odd_even=-1,
):
    """Compare codes with the highest precision neurons on the regex pattern of each code.

    For each code, the best neuron is searched on the code's `regex` using the
    code's `recall` as the neuron recall threshold.

    Args:
        best_codes_info: iterable of CodeInfo objects; each one must provide
            `regex`, `recall` and `prec` attributes.
        tokens_text: list of example texts of a dataset.
        token_byte_pos: numpy array of shape (num_examples, seq_len) where
            `token_byte_pos[example_id][token_pos]` is the byte position of
            the token at `token_pos` in the example with `example_id`.
        neuron_acts_by_ex: numpy array of activations of all the attention and
            mlp output neurons on a dataset, with shape
            (num_examples, seq_len, num_layers, 2, dim_size). The fourth
            dimension has size 2 because neurons from both attention and mlp
            are considered.
        neuron_sorted_acts: sorted activations of the same neurons with shape
            (num_layers, 2, dim_size, num_examples * seq_len). This should be
            obtained from `neuron_acts_by_ex` by moving the first two
            dimensions to the end, flattening them and sorting the last
            dimension.
        at_odd_even: limit regex matches by the parity of their 0-based
            character offset in the text.
            -1 (default): do not limit matches.
            0: keep only matches starting at an even offset.
            1: keep only matches starting at an odd offset.
            This is useful for the TokFSM dataset when searching for states
            since the first token of states are always at even positions.

    Returns:
        codes_better_than_neurons: fraction of codes whose precision is higher
            than that of the highest precision neuron on the code's regex
            pattern (NaN when `best_codes_info` is empty).
        code_best_precs: array of the precision of each code in `best_codes_info`.
        neuron_best_prec: array of the highest neuron precision on each code's
            regex pattern.
    """
    assert isinstance(neuron_acts_by_ex, np.ndarray)
    # The input is iterated twice below, so materialise it once.
    best_codes_info = list(best_codes_info)
    # Only the best precision (first return value) of each neuron search is used.
    neuron_best_prec = np.array(
        [
            get_neurons_from_pattern(
                code_info.regex,
                tokens_text,
                token_byte_pos,
                neuron_acts_by_ex,
                neuron_sorted_acts,
                code_info.recall,
                at_odd_even=at_odd_even,
            )[0]
            for code_info in tqdm(best_codes_info)
        ]
    )
    code_best_precs = np.array([code_info.prec for code_info in best_codes_info])
    if len(code_best_precs) == 0:
        # Nothing to compare; avoid taking the mean of an empty array.
        return float("nan"), code_best_precs, neuron_best_prec
    codes_better_than_neurons = code_best_precs > neuron_best_prec
    return codes_better_than_neurons.mean(), code_best_precs, neuron_best_prec