# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from transformers import (
    LogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation.streamers import BaseStreamer

from chameleon.inference.alignment import AlignPromptLeft, PromptAlignment
from chameleon.inference.model_adapter import ModelAdapter
from chameleon.inference.stopping_criteria import StoppingCriteria, StoppingCriteriaList
from chameleon.inference.token_selector import MultinomialTokenSelector, TokenSelector


class ChameleonGenerator:
    """Iterator that produces tokens one step at a time from a ``ModelAdapter``.

    Each ``next()`` call behaves as follows:

    1. It checks ``stopping_criteria`` against the tokens accumulated so far
       and raises ``StopIteration`` if it fires.
    2. During the first ``start_index`` steps (as reported by ``alignment``),
       it replays the prompt tokens without running the model.
    3. On every later step, it:
       - runs the model on the accumulated inputs;
       - passes the last position's logits through ``logits_processors``;
       - softmaxes them and passes the probabilities through
         ``probability_processors``;
       - asks ``token_selector`` for the next token and appends it;
       - lets ``alignment`` post-process the inputs.
    """

    @dataclass
    class Token:
        """Result of a single generation step."""

        # Token ids emitted at this step, one per batch row (shape [batch]).
        id: torch.LongTensor
        # Processed logits ([batch, vocab]) the token was selected from, or
        # None for prompt tokens that were replayed without a model call.
        logits: torch.Tensor | None

    def __init__(
        self,
        model: ModelAdapter,
        input_ids: list[list[int]],
        stopping_criteria: StoppingCriteriaList | list[StoppingCriteria] | None = None,
        logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
        probability_processors: LogitsProcessorList
        | list[LogitsProcessor]
        | None = None,
        token_selector: TokenSelector | None = None,
        alignment: PromptAlignment = AlignPromptLeft(),
    ):
        """Set up generation state and initialize ``model`` with the prompts.

        Args:
            model: Adapter used for forward passes. It must report support for
                ``alignment``, and ``model.initialize(input_ids)`` is called
                once here.
            input_ids: One list of prompt token ids per batch row.
            stopping_criteria: Checked before every step; ``None`` is treated
                as an empty list.
            logits_processors: Applied to the last-position logits on every
                model step; ``None`` is treated as an empty list.
            probability_processors: Applied to the softmaxed probabilities on
                every model step; ``None`` is treated as an empty list.
            token_selector: Chooses the next token from the probabilities.
                Defaults to ``MultinomialTokenSelector()``.
            alignment: Controls how prompts are laid out and post-processed.
                NOTE(review): the default ``AlignPromptLeft()`` instance is
                created once at definition time and shared by every generator
                that relies on the default; presumably stateless — confirm.

        Raises:
            ValueError: If ``model`` does not support ``alignment``.
        """
        # Explicit check rather than ``assert``, so the validation still
        # happens when Python runs with optimizations (-O) enabled.
        if not model.supports_alignment(alignment):
            raise ValueError(
                f"{type(model).__name__} does not support prompt alignment "
                f"{type(alignment).__name__}"
            )

        self.model = model

        # These assignments go through the property setters below, which
        # normalize None / plain lists into the corresponding *List wrappers.
        self.stopping_criteria = stopping_criteria
        self.logits_processors = logits_processors
        self.probability_processors = probability_processors
        self.token_selector: TokenSelector = (
            token_selector or MultinomialTokenSelector()
        )

        self.alignment = alignment

        self.model.initialize(input_ids)

        self._inputs = self.alignment.prepare_inputs(
            input_ids
        )  # inputs.shape = [batch, seq-len]

        # Number of prompt tokens already replayed to the caller.
        self._idx = 0
        # Position at which model-driven generation begins.
        self._start_idx = self.alignment.start_index(input_ids)

        # Keep the full prepared prompt for alignment post-processing, then
        # truncate the working inputs to the portion emitted up front.
        self._original_inputs = self._inputs.clone()
        self._inputs = self._inputs[:, : self._start_idx]

    def __iter__(self):
        return self

    @torch.inference_mode()
    def __next__(self) -> Token:
        """Produce the next ``Token``, or raise ``StopIteration`` when done."""
        # Are we done? This is checked before every step, including the
        # prompt-replay steps below.
        if self.stopping_criteria(self._inputs, None):
            raise StopIteration

        # Emit initial tokens.
        # Model is not run for these.
        # If you want the logits, you can do a separate forward pass outside generation.
        if self._idx < self._start_idx:
            idx, self._idx = self._idx, self._idx + 1
            return ChameleonGenerator.Token(id=self._inputs[:, idx], logits=None)

        # Run the model for the next token.
        self._inputs = self._inputs.contiguous()
        outputs = self.model(self._inputs)  # outputs.shape = [batch, seq-len, vocab]

        # Pull out and process the logits.
        logits = outputs[:, -1, :]  # logits.shape = [batch, vocab]
        logits = self.logits_processors(self._inputs, logits)
        probs = logits.softmax(dim=1)  # probs.shape = [batch, vocab]
        probs = self.probability_processors(self._inputs, probs)

        # Select a token and add it to the inputs.
        next_tokens = self.token_selector(
            self._inputs, probs
        )  # next_tokens.shape = [batch]
        self._inputs = torch.cat([self._inputs, next_tokens[:, None]], dim=1)

        # Run alignment specific postprocessing.
        self._inputs = self.alignment.postprocess_inputs(
            self._inputs, self._original_inputs
        )

        # Return the next step result. The id is read back from the
        # post-processed inputs, so it reflects any alignment adjustment.
        return ChameleonGenerator.Token(id=self._inputs[:, -1], logits=logits)

    @property
    def stopping_criteria(self) -> StoppingCriteriaList:
        """Criteria evaluated before each step; any hit ends iteration."""
        return self._stopping_criteria

    @stopping_criteria.setter
    def stopping_criteria(
        self, value: StoppingCriteriaList | list[StoppingCriteria] | None
    ):
        # None and plain lists are wrapped so the getter always returns a list type.
        self._stopping_criteria = StoppingCriteriaList(value or [])

    @property
    def logits_processors(self) -> LogitsProcessorList:
        """Processors applied to the raw last-position logits."""
        return self._logits_processors

    @logits_processors.setter
    def logits_processors(
        self, value: LogitsProcessorList | list[LogitsProcessor] | None
    ):
        # None and plain lists are wrapped so the getter always returns a list type.
        self._logits_processors = LogitsProcessorList(value or [])

    @property
    def probability_processors(self) -> LogitsProcessorList:
        """Processors applied to the softmaxed probabilities."""
        return self._probability_processors

    @probability_processors.setter
    def probability_processors(
        self, value: LogitsProcessorList | list[LogitsProcessor] | None
    ):
        # None and plain lists are wrapped so the getter always returns a list type.
        self._probability_processors = LogitsProcessorList(value or [])


def run_generation(
    model: ModelAdapter,
    input_ids: list[list[int]],
    stopping_criteria: StoppingCriteriaList | list[StoppingCriteria],
    logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
    probability_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
    token_selector: TokenSelector | None = None,
    alignment: PromptAlignment = AlignPromptLeft(),
    streamer: BaseStreamer | None = None,
) -> torch.LongTensor:
    """Run ``ChameleonGenerator`` to completion and collect every emitted token.

    Args:
        model: Adapter passed through to ``ChameleonGenerator``.
        input_ids: One list of prompt token ids per batch row.
        stopping_criteria: Criteria that end generation.
        logits_processors: Forwarded to ``ChameleonGenerator``.
        probability_processors: Forwarded to ``ChameleonGenerator``.
        token_selector: Forwarded to ``ChameleonGenerator``.
        alignment: Forwarded to ``ChameleonGenerator``.
        streamer: If given, receives each step's token ids via ``put`` and is
            closed with ``end`` once generation finishes.

    Returns:
        Tensor of shape ``[len(input_ids), steps]`` holding each step's token
        ids (prompt replay included) as columns, in emission order.
    """
    # Collect one [batch, 1] column per step and concatenate once at the end.
    # Re-concatenating onto a growing tensor every step copies the whole
    # accumulated result each time (quadratic in the number of steps). It also
    # mixed a CPU-allocated seed tensor with token tensors that may live on
    # another device.
    columns: list[torch.Tensor] = []
    for tok in ChameleonGenerator(
        model=model,
        input_ids=input_ids,
        stopping_criteria=stopping_criteria,
        logits_processors=logits_processors,
        probability_processors=probability_processors,
        token_selector=token_selector,
        alignment=alignment,
    ):
        if streamer is not None:
            streamer.put(tok.id)
        columns.append(tok.id.view(-1, 1))

    if streamer is not None:
        streamer.end()

    if not columns:
        # Nothing was emitted: keep the original empty [batch, 0] int64 result.
        return torch.empty((len(input_ids), 0), dtype=int)

    result = torch.cat(columns, dim=1)
    # Match the dtype the original incremental concat produced: int64
    # promoted with the token dtype. This is a no-op when ids are already int64.
    return result.to(torch.promote_types(result.dtype, torch.int64))