from typing import Dict, Sequence, Tuple
import re
import numpy as np
import torch


def postprocess_classification_generation(predictions: str) -> str:
    """Truncate generated text at the first ``Prompt`` or ``Completion`` marker.

    Everything before the first occurrence of either marker is kept; if
    neither marker occurs, the input is returned unchanged.

    Args:
        predictions: raw generated text.

    Returns:
        The prefix of ``predictions`` preceding the first marker.
    """
    # maxsplit is passed by keyword: passing it positionally is deprecated
    # since Python 3.13.
    return re.split(r"Prompt|Completion", predictions, maxsplit=1)[0]


def compute_classification_accuracy(predictions: Sequence[Dict[str, str]]) -> float:
    """Return the fraction of records whose prediction matches the class label.

    The comparison is case-insensitive: for every record, both the
    ``"prediction"`` and the ``"class_label"`` strings are lower-cased before
    being tested for equality.

    Args:
        predictions: records that each carry a ``"prediction"`` and a
            ``"class_label"`` string.

    Returns:
        The mean of the per-record match indicators, as a Python float.
    """
    matches = []
    for record in predictions:
        predicted = record["prediction"].lower()
        expected = record["class_label"].lower()
        matches.append(predicted == expected)

    return np.mean(matches).item()


def compute_shifted_logits_and_labels(
    logits: torch.Tensor, encodings, tokenizer, eoc_token_id
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Helper function to compute shifted logits and labels.

    This allows for straightforward computation of the loss on shift_logits
    and shift_labels such that the nth element of logits computes the n-1th
    element of the original labels (in the outputs, the nth element of logits
    corresponds to the nth element of the labels).

    Label positions that must not contribute to a loss are set to -100 (the
    default ``ignore_index`` of ``torch.nn.functional.cross_entropy``):
    padding tokens, ``eoc_token_id`` tokens, and every token up to and
    including the last ``tokenizer.eos_token_id`` of each row (the prefix).

    Args:
        logits: model output scores; the last dimension is the vocabulary
            and the second-to-last is the sequence position.
        encodings: mapping whose ``"input_ids"`` entry is an integer tensor
            of shape [batch_size, seq_len].
        tokenizer: object exposing ``pad_token_id`` and ``eos_token_id``.
        eoc_token_id: token id that is excluded from the labels.

    Returns: tuple containing two elements:
        shift_logits: a float Tensor of shape [batch_size, seq_len - 1, vocab_size].
        shift_labels: an integer Tensor of shape [batch_size, seq_len - 1]

    Raises:
        ValueError: if a row contains no (unmasked) eos token marking the
            end of its prefix.
    """
    labels = encodings["input_ids"].clone()

    # convert padding and EOC tokens to -100 so they are ignored in loss
    labels[labels == tokenizer.pad_token_id] = -100
    labels[labels == eoc_token_id] = -100

    # Convert all tokens in prefix until separator to -100 so they are
    # ignored in loss. The search is done *from the right*, since the first
    # non-padding token of the sequence will also be eos_token (because
    # bos_token and eos_token are the same for the tokenizer).
    for row_idx, row in enumerate(labels.tolist()):
        try:
            offset_from_right = row[::-1].index(tokenizer.eos_token_id)
        except ValueError:
            raise ValueError(
                f"Row {row_idx} contains no eos token "
                f"(id {tokenizer.eos_token_id}) separating prefix from target."
            ) from None
        # Positive index of the separator. A negative index would produce an
        # empty slice ``[: 0]`` when the separator is the final token, which
        # would leave the whole prompt unmasked.
        end_of_prefix = len(row) - 1 - offset_from_right
        labels[row_idx, : end_of_prefix + 1] = -100

    # Shift so that tokens < n predict n. shift_logits keeps the vocab
    # dimension; shift_labels has shape [batch_size, seq_len - 1].
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    return shift_logits, shift_labels


def compute_per_sample_probs(
    encodings, tokenizer, logits: torch.Tensor, eoc_token_id
) -> torch.Tensor:
    """Helper function to compute per-sample probability of the target sequence.

    Assumes <eos token> is used to separate inputs from targets in the
    prompt text. Only label positions left unmasked (!= -100) by
    ``compute_shifted_logits_and_labels`` contribute to the probability.

    Returns:
        Tensor of shape [batch_size] on ``logits``' device with dtype
        ``torch.get_default_dtype()``; element i is the product, over the
        unmasked target positions of row i, of the softmax probability
        assigned to the target token.

    Raises:
        AssertionError: if some batch element has no unmasked target token.
    """
    shift_logits, shift_labels = compute_shifted_logits_and_labels(
        logits, encodings, tokenizer, eoc_token_id
    )
    device = shift_logits.device
    shift_labels = shift_labels.to(device)

    # True for target tokens that contribute to the sequence probability.
    target_mask = shift_labels != -100

    # Sanity check that every element in batch has at least one unmasked
    # target token. A per-row check is used because a bincount over batch
    # indices is too short to notice trailing rows that have no targets.
    assert torch.all(
        target_mask.any(dim=1)
    ), "At least one element in batch has no unmasked target tokens."

    # Renormalize over tokens to make sure they are proper probabilities via
    # softmax over the token dimension.
    shift_probs = torch.nn.functional.softmax(shift_logits, 2)

    # gather needs a valid token id at every position, so masked (-100)
    # positions read a placeholder id 0. Their values are then replaced by
    # 1.0 so they do not affect the product.
    gather_ids = shift_labels.masked_fill(~target_mask, 0).long().unsqueeze(-1)
    token_probs = shift_probs.gather(2, gather_ids).squeeze(-1)

    # Accumulate in the default float dtype, matching a float accumulator
    # created with torch.ones, regardless of the logits' precision.
    token_probs = token_probs.to(torch.get_default_dtype())
    token_probs = token_probs.masked_fill(~target_mask, 1.0)

    # Probability of the target sequence = product of per-token probabilities.
    return token_probs.prod(dim=1)


def compute_per_sample_loss(encodings, tokenizer, logits, eoc_token_id) -> torch.Tensor:
    """Helper function to compute per-sample classification loss.

    Assumes <eos token> is used to separate inputs from targets in the
    prompt text. The loss of each sample is the mean token-level cross
    entropy over its unmasked (!= -100) target positions.

    Returns:
        CPU tensor of shape [batch_size]. A row with no unmasked target
        tokens yields NaN (0 / 0).
    """
    shift_logits, shift_labels = compute_shifted_logits_and_labels(
        logits, encodings, tokenizer, eoc_token_id
    )

    device = shift_logits.device

    # Loss is computed token-wise, on Tensors of shape
    # [batch_size * (seq_len - 1), vocab_size]
    # and returns a loss tensor of shape
    # [batch_size * (seq_len - 1)]. Most of the tokens will be masked
    # in this computation.
    loss = torch.nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1).to(device),
        reduction="none",
    )

    # Reshape to [batch_size, seq_len - 1]
    loss = loss.view(shift_logits.size(0), shift_logits.size(1)).cpu()

    # loss_mask is 1 for tokens we want included in the loss, and 0 for tokens
    # that should be ignored in the loss.
    loss_mask = (shift_labels != -100).int().cpu()

    loss *= loss_mask

    # Compute per-element loss: sum loss over all (unmasked) tokens and
    # divide by the number of unmasked target tokens to obtain a tensor of
    # shape [batch_size,]. The count comes from the CPU mask so numerator and
    # denominator are on the same device even when encodings live on a GPU.
    num_target_tokens = loss_mask.sum(dim=1).float()
    return loss.sum(dim=1) / num_target_tokens