# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from typing import Any, Literal, Optional

import torch
# import torch._dynamo.config
# import torch._inductor.config

from litgpt.model import GPT
from utils.snac_utils import layershift, snac_config
from tqdm import tqdm


def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
    """Draw a single index from the categorical distribution described by ``probs``.

    Outside of a ``torch._dynamo`` compilation this simply defers to
    ``torch.multinomial``. While compiling, it divides ``probs`` by exponential
    noise and takes the argmax over the last dimension instead.

    Args:
        probs: Probabilities to sample from.

    Returns:
        The sampled index, with the sampled dimension kept (size 1).
    """
    if not torch._dynamo.is_compiling():
        return torch.multinomial(probs, num_samples=1)
    # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
    exp_noise = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / exp_noise, dim=-1, keepdim=True)


def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Mask the low-probability tail of ``logits`` for nucleus (top-p) sampling.

    The logits are sorted in ascending order along the last dimension; every
    token whose cumulative probability (counted from the least likely token)
    is ``<= 1 - top_p`` is replaced by ``-inf``. The most probable token is
    always kept.

    Works on 1-D logits of shape ``(vocab,)`` as well as on batched logits of
    shape ``(..., vocab)``; filtering is applied independently per row.

    Args:
        logits: Unnormalised scores; the last dimension indexes the vocabulary.
        top_p: Cumulative probability mass to keep.

    Returns:
        A new tensor shaped like ``logits`` with the removed entries set to ``-inf``.
    """
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Example:
    # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
    # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Keep at least 1 token always to prevent the case where no token is selected
    # In this case the most probable one (last after the ascending sort) is always kept
    sorted_indices_to_remove[..., -1:] = False
    # Map the mask from sorted order back to vocabulary order, along the vocabulary dimension
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )
    return logits.masked_fill(indices_to_remove, float("-inf"))


def sample(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
) -> torch.Tensor:
    """Pick the next token id from the logits of the last position.

    Only ``logits[0, -1]`` is considered. Filters are applied in the order
    top-k, temperature scaling, top-p, after which one token is drawn from the
    resulting softmax distribution.

    Note that the greedy ``argmax`` path is taken only when neither
    ``temperature`` nor ``top_p`` is positive; with a non-positive temperature
    but ``top_p > 0`` the (unscaled) distribution is still sampled.

    Args:
        logits: Model output; indexed as ``[0, -1]`` to obtain the vocabulary scores.
        temperature: Divides the logits when positive.
        top_k: If given, only the ``top_k`` highest logits remain candidates.
        top_p: Nucleus threshold; filtering is applied when it is below 1.

    Returns:
        Tensor with a single element holding the chosen token index.

    Raises:
        ValueError: If ``top_p`` is below 0 or above 1.
    """
    if top_p > 1.0 or top_p < 0.0:
        raise ValueError(f"top_p must be in [0, 1], got {top_p}")

    last_logits = logits[0, -1]

    # optionally crop the logits to only the top k options
    if top_k is not None:
        k = min(top_k, last_logits.size(-1))
        top_values, top_indices = torch.topk(last_logits, k)
        # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
        filtered = torch.full_like(last_logits, float("-inf"))
        last_logits = filtered.scatter_(-1, top_indices, top_values)

    # greedy decoding when neither temperature nor top_p asks for sampling
    if not (temperature > 0.0 or top_p > 0.0):
        return torch.argmax(last_logits, dim=-1, keepdim=True)

    # optionally scale the logits and sample from a probability distribution
    if temperature > 0.0:
        last_logits = last_logits / temperature
    # optionally crop the logits to smallest set of logits with a cumulative probability above top_p
    if top_p < 1.0:
        last_logits = sample_top_p(last_logits, top_p)
    return multinomial_num_samples_1(
        torch.nn.functional.softmax(last_logits, dim=-1)
    )


def next_token(
    model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
) -> torch.Tensor:
    """Run one forward pass and sample the next token for every output head.

    Args:
        model: Called as ``model(x, input_pos)``; expected to return a pair
            ``(logits_a, logit_t)`` where ``logits_a`` is iterable.
        input_pos: Positions for this step; moved to ``model.device`` first.
        x: List of input tensors; the dtype of ``x[0]`` is used for the sampled tokens.
        **kwargs: Forwarded to ``sample`` (temperature, top_k, top_p).

    Returns:
        ``(next_audio_tokens, next_t)``: a list with one sampled token per entry
        of ``logits_a``, and the token sampled from ``logit_t``.
    """
    audio_logits, text_logits = model(x, input_pos.to(model.device))

    def _draw(head_logits: torch.Tensor) -> torch.Tensor:
        return sample(head_logits, **kwargs).to(dtype=x[0].dtype)

    # audio heads are sampled first, the text head last
    sampled_audio = [_draw(head_logits) for head_logits in audio_logits]
    sampled_text = _draw(text_logits)
    return sampled_audio, sampled_text


def next_token_asr(
    model: GPT,
    input_pos: torch.Tensor,
    audio_features: torch.tensor,
    lens: int,
    input_ids: list,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass with audio features and sample the next tokens.

    Args:
        model: Called as ``model(audio_features, input_ids, input_pos, whisper_lens=lens)``;
            expected to return a pair ``(logits_a, logit_t)``.
        input_pos: Positions for this step; moved to ``model.device``.
        audio_features: Passed to the model unchanged.
        lens: Forwarded to the model as ``whisper_lens``.
        input_ids: List of token tensors; each is moved to ``model.device``.
        **kwargs: Forwarded to ``sample``.

    Returns:
        ``(next_audio_tokens, next_t)``: one sampled token per entry of
        ``logits_a`` plus the token sampled from ``logit_t``, all cast to the
        dtype of the first input tensor.
    """
    ids_on_device = [ids.to(model.device) for ids in input_ids]
    audio_logits, text_logits = model(
        audio_features, ids_on_device, input_pos.to(model.device), whisper_lens=lens
    )

    def _draw(head_logits: torch.Tensor) -> torch.Tensor:
        return sample(head_logits, **kwargs).to(dtype=ids_on_device[0].dtype)

    # audio heads are sampled first, the text head last
    sampled_audio = [_draw(head_logits) for head_logits in audio_logits]
    sampled_text = _draw(text_logits)
    return sampled_audio, sampled_text


def next_token_A1T2(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass and sample the next token of every audio head and of the text head.

    Args:
        model: Called as
            ``model(audio_features, input_ids, input_pos, whisper_lens=..., task=...)``;
            expected to return a pair ``(logits_a, logit_t)``.
        audio_features: Passed to the model unchanged.
        input_ids: List of token tensors; each is moved to ``model.device``.
        whisper_lens: Forwarded to the model.
        task: Forwarded to the model.
        input_pos: Positions for this step; moved to ``model.device``.
        **kwargs: Forwarded to ``sample``.

    Returns:
        ``(next_audio_tokens, next_t)``: one sampled token per entry of
        ``logits_a`` plus the token sampled from ``logit_t``, all cast to the
        dtype of the first input tensor.
    """
    ids_on_device = [ids.to(model.device) for ids in input_ids]
    audio_logits, text_logits = model(
        audio_features,
        ids_on_device,
        input_pos.to(model.device),
        whisper_lens=whisper_lens,
        task=task,
    )

    def _draw(head_logits: torch.Tensor) -> torch.Tensor:
        return sample(head_logits, **kwargs).to(dtype=ids_on_device[0].dtype)

    # audio heads are sampled first, the text head last
    sampled_audio = [_draw(head_logits) for head_logits in audio_logits]
    sampled_text = _draw(text_logits)
    return sampled_audio, sampled_text


def next_token_A1T1(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass and sample only the next text token.

    The audio logits returned by the model are ignored.

    Args:
        model: Called as
            ``model(audio_features, input_ids, input_pos, whisper_lens=..., task=...)``;
            expected to return a pair whose second element is the text logits.
        audio_features: Passed to the model unchanged.
        input_ids: List of token tensors; each is moved to ``model.device``.
        whisper_lens: Forwarded to the model.
        task: Forwarded to the model.
        input_pos: Positions for this step; moved to ``model.device``.
        **kwargs: Forwarded to ``sample``.

    Returns:
        The sampled text token, cast to the dtype of the first input tensor.
    """
    ids_on_device = [ids.to(model.device) for ids in input_ids]
    _, text_logits = model(
        audio_features,
        ids_on_device,
        input_pos.to(model.device),
        whisper_lens=whisper_lens,
        task=task,
    )
    return sample(text_logits, **kwargs).to(dtype=ids_on_device[0].dtype)


def next_token_batch(
    model: GPT,
    audio_features: torch.tensor,
    input_ids: list,
    whisper_lens: int,
    task: list,
    input_pos: torch.Tensor,
    **kwargs: Any,
) -> torch.Tensor:
    """Run one forward pass on a two-row batch and sample the next tokens.

    Audio tokens are sampled from batch row 0 of the first seven audio heads,
    the text token from batch row 1 of the text head.

    Args:
        model: Called as
            ``model(audio_features, input_ids, input_pos, whisper_lens=..., task=...)``;
            expected to return a pair ``(logits_a, logit_t)``.
        audio_features: Passed to the model unchanged.
        input_ids: List of token tensors; each is moved to ``model.device``.
        whisper_lens: Forwarded to the model.
        task: Forwarded to the model.
        input_pos: Positions for this step; moved to ``model.device``.
        **kwargs: Forwarded to ``sample``.

    Returns:
        ``(next_audio_tokens, next_t)``: one sampled token per entry of
        ``logits_a`` plus the sampled text token, all cast to the dtype of the
        first input tensor.
    """
    ids_on_device = [ids.to(model.device) for ids in input_ids]
    logits_a, logit_t = model(
        audio_features,
        ids_on_device,
        input_pos.to(model.device),
        whisper_lens=whisper_lens,
        task=task,
    )

    # keep only batch row 0 for the seven audio heads (re-adding the batch dim) ...
    for head in range(7):
        row0 = logits_a[head][0]
        logits_a[head] = row0.unsqueeze(0)
    # ... and only batch row 1 for the text head
    text_logits = logit_t[1].unsqueeze(0)

    def _draw(head_logits: torch.Tensor) -> torch.Tensor:
        return sample(head_logits, **kwargs).to(dtype=ids_on_device[0].dtype)

    sampled_audio = [_draw(head_logits) for head_logits in logits_a]
    sampled_text = _draw(text_logits)
    return sampled_audio, sampled_text


# torch._dynamo.config.automatic_dynamic_shapes = True
# torch._inductor.config.triton.unique_kernel_names = True
# torch._inductor.config.coordinate_descent_tuning = True
# next_token = torch.compile(next_token, mode="reduce-overhead")


@torch.inference_mode()
def generate(
    model: GPT,
    input_ids: list,
    max_returned_tokens: int,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> list:
    """
    Takes eight parallel conditioning sequences (prompt) as input and continues to generate as many tokens as
    requested. The implementation of this function is modified from A. Karpathy's nanoGPT.

    Args:
        model: The model to use.
        input_ids: Exactly eight tensors of shape (T): the seven audio streams followed by the text stream.
        max_returned_tokens: The maximum number of tokens to return per stream (given plus generated).
        temperature: Scales the predicted logits by 1 / temperature.
        top_k: If specified, only sample among the tokens with the k highest probabilities.
        top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
            In top-p sampling, the next token is sampled from the highest probability tokens
            whose cumulative probability exceeds the threshold `top_p`. When specified,
            it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
            to sampling the most probable token, while `top_p=1` samples from the whole distribution.
            It can be used in conjunction with `top_k` and `temperature` with the following order
            of application:

            1. `top_k` sampling
            2. `temperature` scaling
            3. `top_p` sampling

            For more details, see https://arxiv.org/abs/1904.09751
            or https://huyenchip.com/2024/01/16/sampling.html#top_p
        eos_id_a: If specified, stop generating once the last audio stream produces this token.
        eos_id_t: If specified, marks the end of the text stream. With `generate_text=True` generation stops;
            otherwise every later text token is replaced by `pad_id`.
        pad_id: Text token used after the text stream has ended. Must be set if that can happen.
        shift: Offset added to every sampled audio token before it is fed back to the model
            (together with `i * snac_config.padded_vocab_size` for audio stream `i`). Must be an int;
            the default `None` is not usable.
        include_prompt: Currently unused; the prompt is always part of the output.
        generate_text: If true, stop as soon as the text stream produces `eos_id_t`.

    Returns:
        A list of eight 1-D tensors (seven audio streams, then text), each the prompt followed by the
        generated tokens of that stream.

    Raises:
        NotImplementedError: If `model.max_seq_length` is smaller than `max_returned_tokens - 1`.
        ValueError: If `input_ids` does not contain exactly eight streams.
    """
    T = input_ids[0].size(0)
    device = input_ids[0].device
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
        # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
        # not support it to avoid negatively impacting the overall speed
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )

    num_audio = 7
    if len(input_ids) != num_audio + 1:
        raise ValueError(
            f"expected {num_audio + 1} input streams (7 audio + 1 text), got {len(input_ids)}"
        )

    # one growing list of tensors per stream, seeded with the prompt
    list_output = [[stream] for stream in input_ids]

    input_pos = torch.tensor([T], device=device)
    model_input_ids = [stream.view(1, -1) for stream in input_ids]

    # prefill: feed the whole prompt and sample the first new token of each stream
    tokens_A, token_T = next_token(
        model,
        torch.arange(0, T, device=device),
        model_input_ids,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    for i in range(num_audio):
        list_output[i].append(tokens_A[i].clone())
    list_output[num_audio].append(token_T.clone())

    # prepare the input for the next iteration
    for i in range(num_audio):
        tokens_A[i] = tokens_A[i] + shift + i * snac_config.padded_vocab_size

    text_end = False
    # The token budget comes from the caller's `max_returned_tokens`, which was validated against
    # `model.max_seq_length` above.
    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = [
            token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
        ] + [token_T.view(1, -1).to(torch.int32)]
        tokens_A, token_T = next_token(
            model,
            input_pos,
            model_input_ids,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if text_end:
            # the text stream already finished: keep feeding padding
            token_T = torch.tensor([pad_id], device=device)

        for i in range(num_audio):
            list_output[i].append(tokens_A[i].clone())
        list_output[num_audio].append(token_T.clone())

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            if generate_text:
                break
            text_end = True

        for i in range(num_audio):
            tokens_A[i] = tokens_A[i] + shift + i * snac_config.padded_vocab_size
        input_pos = input_pos.add_(1)

    return [torch.cat(stream) for stream in list_output]


def _build_batch_step_inputs(tokens_A: list, token_T: torch.Tensor, shift, device) -> list:
    """Assemble the eight two-row input tensors for the next batched decoding step."""
    model_input_ids = []
    for i in range(7):
        shifted = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
        # row 0 carries the sampled audio token, row 1 the layer-shifted end-of-audio id
        model_input_ids.append(
            torch.stack(
                [
                    shifted.to(device).to(torch.int32),
                    torch.tensor([layershift(snac_config.end_of_audio, i)], device=device),
                ]
            )
        )
    # both rows receive the same text token
    text_row = token_T.clone().to(torch.int32)
    model_input_ids.append(torch.stack([text_row, text_row]))
    return model_input_ids


@torch.inference_mode()
def generate_TA_BATCH(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 1000,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> list:
    """Generate audio and text tokens with a two-row batch.

    The prompt is run once through `next_token_batch`; afterwards each step feeds a
    two-row batch in which row 0 holds the previously sampled audio tokens and row 1
    the end-of-audio ids, while both rows hold the previously sampled text token.

    Args:
        model: The model to use.
        audio_features: Audio features for the prompt; cast to float32 and moved to `model.device`.
        input_ids: List of eight prompt tensors; `input_ids[0].size(1)` is taken as the prompt length T.
        leng: Currently unused; `[T - 3, T - 3]` is passed to the model instead.
        task: Currently unused; `["A1T2", "A1T2"]` is passed to the model instead.
        max_returned_tokens: Upper bound on prompt plus generated tokens.
        temperature: Forwarded to `sample`.
        top_k: Forwarded to `sample`.
        top_p: Forwarded to `sample`.
        eos_id_a: Generation stops (without storing the step) once the last audio head samples this token.
        eos_id_t: Once sampled by the text head, later text tokens are replaced by `pad_id_t`.
        pad_id_t: Text token used after the text stream has ended.
        shift: Offset added to sampled audio tokens before they are fed back. Must be an int.
        include_prompt: Currently unused.
        generate_text: Currently unused.

    Returns:
        A list of eight lists of token ids (seven audio streams, then text), generated tokens only.

    Raises:
        NotImplementedError: If `model.max_seq_length` is smaller than `max_returned_tokens - 1`.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device
    assert max_returned_tokens > T
    if model.max_seq_length < max_returned_tokens - 1:
        raise NotImplementedError(
            f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
        )

    input_pos = torch.tensor([T], device=device)
    list_output = [[] for _ in range(8)]

    # prefill with the full prompt
    tokens_A, token_T = next_token_batch(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3, T - 3],
        ["A1T2", "A1T2"],
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    for i in range(7):
        list_output[i].append(tokens_A[i].tolist()[0])
    list_output[7].append(token_T.tolist()[0])

    model_input_ids = _build_batch_step_inputs(tokens_A, token_T, shift, device)

    text_end = False

    for _ in range(2, max_returned_tokens - T + 1):
        tokens_A, token_T = next_token_batch(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        if text_end:
            # the text stream already finished: keep feeding padding
            token_T = torch.tensor([pad_id_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_end = True

        for i in range(7):
            list_output[i].append(tokens_A[i].tolist()[0])
        list_output[7].append(token_T.tolist()[0])

        model_input_ids = _build_batch_step_inputs(tokens_A, token_T, shift, device)
        input_pos = input_pos.add_(1)

    return list_output


@torch.inference_mode()
def generate_TT(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> list:
    """Generate text tokens from a prompt without audio features.

    The prompt is run once through `next_token_A1T1` (with no audio features, lengths
    or task); afterwards each step feeds the previously sampled text token together
    with the layer-shifted end-of-audio id on each of the seven audio streams.

    Args:
        model: The model to use.
        audio_features: Currently unused.
        input_ids: List of prompt tensors; `input_ids[0].size(1)` is taken as the prompt length T.
        leng: Currently unused.
        task: Currently unused.
        max_returned_tokens: Upper bound on prompt plus generated tokens.
        temperature: Forwarded to `sample`.
        top_k: Forwarded to `sample`.
        top_p: Forwarded to `sample`.
        eos_id_a: Currently unused.
        eos_id_t: Generation stops (without storing the token) once this text token is sampled
            inside the loop; the very first sampled token is stored without this check.
        pad_id_t: Currently unused.
        shift: Currently unused.
        include_prompt: Currently unused.
        generate_text: Currently unused.

    Returns:
        The list of generated text token ids.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device

    token_T = next_token_A1T1(
        model,
        None,
        input_ids,
        None,
        None,
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )

    output = [token_T.clone().tolist()[0]]
    input_pos = torch.tensor([T], device=device)

    # The audio streams receive the same end-of-audio ids at every step, so build (and move to the
    # device) those seven tensors once instead of once per generated token.
    audio_pad_inputs = [
        torch.tensor([layershift(snac_config.end_of_audio, i)])
        .view(1, -1)
        .to(torch.int32)
        .to(device)
        for i in range(7)
    ]

    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = audio_pad_inputs + [
            token_T.clone().view(1, -1).to(torch.int32).to(device)
        ]
        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)
    return output


@torch.inference_mode()
def generate_AT(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> list:
    """Generate text tokens from a prompt that carries audio features.

    The prompt is run once through `next_token_A1T1` with the audio features, length
    `[T - 3]` and task `["AT"]`; afterwards each step feeds the previously sampled text
    token together with the layer-shifted end-of-audio id on each of the seven audio streams.

    Args:
        model: The model to use.
        audio_features: Audio features for the prompt; cast to float32 and moved to `model.device`.
        input_ids: List of prompt tensors; `input_ids[0].size(1)` is taken as the prompt length T.
        leng: Currently unused; `[T - 3]` is passed to the model instead.
        task: Currently unused; `["AT"]` is passed to the model instead.
        max_returned_tokens: Upper bound on prompt plus generated tokens.
        temperature: Forwarded to `sample`.
        top_k: Forwarded to `sample`.
        top_p: Forwarded to `sample`.
        eos_id_a: Currently unused.
        eos_id_t: Generation stops (without storing the token) once this text token is sampled
            inside the loop; the very first sampled token is stored without this check.
        pad_id_t: Currently unused.
        shift: Currently unused.
        include_prompt: Currently unused.
        generate_text: Currently unused.

    Returns:
        The list of generated text token ids.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device

    token_T = next_token_A1T1(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["AT"],
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    output = [token_T.clone().tolist()[0]]
    input_pos = torch.tensor([T], device=device)

    # The audio streams receive the same end-of-audio ids at every step, so build (and move to the
    # device) those seven tensors once instead of once per generated token.
    audio_pad_inputs = [
        torch.tensor([layershift(snac_config.end_of_audio, i)])
        .view(1, -1)
        .to(torch.int32)
        .to(device)
        for i in range(7)
    ]

    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = audio_pad_inputs + [
            token_T.clone().view(1, -1).to(torch.int32).to(device)
        ]
        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)
    return output


@torch.inference_mode()
def generate_TA(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate audio and text tokens from a prompt without audio features.

    The prompt is run once through `next_token_A1T2` (with no audio features, lengths
    or task). Each following step feeds back the previously sampled tokens, the audio
    ones passed through `layershift` for their stream index.

    `audio_features`, `leng`, `task`, `shift`, `include_prompt` and `generate_text`
    are not used by this function.

    Args:
        model: The model to use.
        input_ids: List of prompt tensors; `input_ids[0].size(1)` is taken as the prompt length.
        max_returned_tokens: Upper bound on prompt plus generated tokens.
        temperature, top_k, top_p: Forwarded to `sample`.
        eos_id_a: Generation stops (without storing the step) once the last audio head samples this token.
        eos_id_t: Once sampled by the text head, later text tokens are replaced by `pad_id_t`.
        pad_id_t: Text token used after the text stream has ended.

    Returns:
        A list of eight lists of token ids: seven audio streams followed by the text stream.
    """
    prompt_len = input_ids[0].size(1)
    device = input_ids[0].device

    # prefill with the full prompt
    tokens_A, token_T = next_token_A1T2(
        model,
        None,
        input_ids,
        None,
        None,
        input_pos=torch.arange(0, prompt_len, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    output = [[tokens_A[layer].clone().tolist()[0]] for layer in range(7)]
    output.append([token_T.clone().tolist()[0]])

    input_pos = torch.tensor([prompt_len], device=device)
    text_finished = False
    for _ in tqdm(range(2, max_returned_tokens - prompt_len + 1)):
        # previous step's tokens become this step's inputs
        step_inputs = [
            layershift(tokens_A[layer].clone(), layer)
            .view(1, -1)
            .to(torch.int32)
            .to(device)
            for layer in range(7)
        ]
        step_inputs.append(token_T.clone().view(1, -1).to(torch.int32).to(device))

        tokens_A, token_T = next_token_A1T2(
            model,
            None,
            step_inputs,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        if text_finished:
            # text already ended: overwrite whatever was sampled with padding
            token_T = torch.tensor([pad_id_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break

        if token_T == eos_id_t:
            text_finished = True

        for layer in range(7):
            output[layer].append(tokens_A[layer].clone().tolist()[0])
        output[7].append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)

    return output


@torch.inference_mode()
def generate_AA(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 2048,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> torch.Tensor:
    """Generate audio and text tokens from a prompt that carries audio features.

    The prompt is run once through `next_token_A1T2` with the audio features, length
    `[T - 3]` and task `["A1T2"]`. Each following step feeds back the previously sampled
    tokens, the audio ones passed through `layershift` for their stream index.

    `leng`, `task`, `shift`, `include_prompt` and `generate_text` are not used by
    this function.

    Args:
        model: The model to use.
        audio_features: Audio features for the prompt; cast to float32 and moved to `model.device`.
        input_ids: List of prompt tensors; `input_ids[0].size(1)` is taken as the prompt length T.
        max_returned_tokens: Upper bound on prompt plus generated tokens.
        temperature, top_k, top_p: Forwarded to `sample`.
        eos_id_a: Generation stops (without storing the step) once the last audio head samples this token.
        eos_id_t: Once sampled by the text head, later text tokens are replaced by `pad_id_t`.
        pad_id_t: Text token used after the text stream has ended.

    Returns:
        A list of eight lists of token ids: seven audio streams followed by the text stream.
    """
    prompt_len = input_ids[0].size(1)
    device = input_ids[0].device

    # prefill with the full prompt and its audio features
    tokens_A, token_T = next_token_A1T2(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [prompt_len - 3],
        ["A1T2"],
        input_pos=torch.arange(0, prompt_len, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    output = [[tokens_A[layer].clone().tolist()[0]] for layer in range(7)]
    output.append([token_T.clone().tolist()[0]])

    input_pos = torch.tensor([prompt_len], device=device)

    text_finished = False
    for _ in tqdm(range(2, max_returned_tokens - prompt_len + 1)):
        # previous step's tokens become this step's inputs
        step_inputs = [
            layershift(tokens_A[layer].clone(), layer)
            .view(1, -1)
            .to(torch.int32)
            .to(device)
            for layer in range(7)
        ]
        step_inputs.append(token_T.clone().view(1, -1).to(torch.int32).to(device))

        tokens_A, token_T = next_token_A1T2(
            model,
            None,
            step_inputs,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )

        if text_finished:
            # text already ended: overwrite whatever was sampled with padding
            token_T = torch.tensor([pad_id_t], device=device)

        if tokens_A[-1] == eos_id_a:
            break
        if token_T == eos_id_t:
            text_finished = True

        for layer in range(7):
            output[layer].append(tokens_A[layer].clone().tolist()[0])
        output[7].append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)

    return output


@torch.inference_mode()
def generate_ASR(
    model: GPT,
    audio_features: torch.Tensor,
    input_ids: list,
    leng,
    task,
    max_returned_tokens: int = 1200,
    *,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: float = 1.0,
    eos_id_a: Optional[int] = None,
    eos_id_t: Optional[int] = None,
    pad_id_t: Optional[int] = None,
    shift: Optional[int] = None,
    include_prompt: bool = True,
    generate_text=False,
) -> list:
    """Generate text tokens for the "asr" task from a prompt that carries audio features.

    The prompt is run once through `next_token_A1T1` with the audio features, length
    `[T - 3]` and task `["asr"]`; afterwards each step feeds the previously sampled text
    token together with the layer-shifted end-of-audio id on each of the seven audio streams.

    Args:
        model: The model to use.
        audio_features: Audio features for the prompt; cast to float32 and moved to `model.device`.
        input_ids: List of prompt tensors; `input_ids[0].size(1)` is taken as the prompt length T.
        leng: Currently unused; `[T - 3]` is passed to the model instead.
        task: Currently unused; `["asr"]` is passed to the model instead.
        max_returned_tokens: Upper bound on prompt plus generated tokens.
        temperature: Forwarded to `sample`.
        top_k: Forwarded to `sample`.
        top_p: Forwarded to `sample`.
        eos_id_a: Currently unused.
        eos_id_t: Generation stops (without storing the token) once this text token is sampled
            inside the loop; the very first sampled token is stored without this check.
        pad_id_t: Currently unused.
        shift: Currently unused.
        include_prompt: Currently unused.
        generate_text: Currently unused.

    Returns:
        The list of generated text token ids.
    """
    T = input_ids[0].size(1)
    device = input_ids[0].device

    token_T = next_token_A1T1(
        model,
        audio_features.to(torch.float32).to(model.device),
        input_ids,
        [T - 3],
        ["asr"],
        input_pos=torch.arange(0, T, device=device),
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    output = [token_T.clone().tolist()[0]]
    input_pos = torch.tensor([T], device=device)

    # The audio streams receive the same end-of-audio ids at every step, so build (and move to the
    # device) those seven tensors once instead of once per generated token.
    audio_pad_inputs = [
        torch.tensor([layershift(snac_config.end_of_audio, i)])
        .view(1, -1)
        .to(torch.int32)
        .to(device)
        for i in range(7)
    ]

    for _ in tqdm(range(2, max_returned_tokens - T + 1)):
        model_input_ids = audio_pad_inputs + [
            token_T.clone().view(1, -1).to(torch.int32).to(device)
        ]
        token_T = next_token_A1T1(
            model,
            None,
            model_input_ids,
            None,
            None,
            input_pos=input_pos,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        if token_T == eos_id_t:
            break
        output.append(token_T.clone().tolist()[0])
        input_pos = input_pos.add_(1)
    return output