# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import torch


class TokenSelector:
    """Interface for choosing one token id per row of a batch.

    Subclasses override ``__call__``; the concrete selectors in this module
    return a 1-D tensor of token ids with one entry per batch row.
    """

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Select one token id for each batch row.

        Args:
            input_ids: token ids seen so far, shape ``[batch, seq_len]``.
            probs: per-row weights over the vocabulary, shape ``[batch, vocab]``.

        Returns:
            Selected token ids, shape ``[batch]``.

        Raises:
            NotImplementedError: always; concrete selectors must override this.
        """
        # Fail loudly instead of silently returning None from the base class.
        raise NotImplementedError(
            f"{type(self).__name__} must implement __call__"
        )


class ArgmaxTokenSelector(TokenSelector):
    """Greedy selector: pick the highest-weight token in every row."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # The token history is ignored; probs.shape=[batch, vocab].
        best_ids = torch.argmax(probs, dim=1)
        return best_ids


class MultinomialTokenSelector(TokenSelector):
    """Sampling selector: draw one token per row, weighted by ``probs``."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # The token history is ignored; probs.shape=[batch, vocab].
        drawn = torch.multinomial(probs, num_samples=1)  # [batch, 1]
        # Drop the singleton sample dimension -> [batch].
        return torch.squeeze(drawn, dim=1)


class ReplicatedInputTokenSelector(TokenSelector):
    """Run a selector on one replica of a batch and share its choice.

    The inputs are expected to hold ``n`` replicas stacked along dim 0: rows
    ``[0, batch)`` are replica 0, rows ``[batch, 2*batch)`` are replica 1, and
    so on. The wrapped selector only sees replica 0. Its tokens are then tiled
    ``n`` times, so every replica receives the same selection.
    """

    def __init__(self, token_selector: TokenSelector, n: int):
        """
        Args:
            token_selector: selector applied to the first replica.
            n: number of stacked replicas; must be >= 1.

        Raises:
            ValueError: if ``n`` is less than 1.
        """
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        self.token_selector = token_selector
        self.n = n

    def _first_replica(self, t: torch.Tensor, name: str) -> torch.Tensor:
        """Return the rows of ``t`` belonging to replica 0, validating shape."""
        rows = t.shape[0]
        # torch.chunk would silently produce ceil(rows / n)-sized chunks here,
        # making the tiled output length disagree with the input batch.
        if rows % self.n != 0:
            raise ValueError(
                f"{name} has {rows} rows, which is not a multiple of n={self.n}"
            )
        return t[: rows // self.n]

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Select tokens for replica 0 and repeat them for all ``n`` replicas.

        Args:
            input_ids: shape ``[n*batch, seq_len]``.
            probs: shape ``[n*batch, vocab]``.

        Returns:
            Token ids of shape ``[n*batch]``. Row ``i*batch + j`` holds the
            token chosen for row ``j`` of replica 0.

        Raises:
            ValueError: if a row count is not a multiple of ``n``.
        """
        primary_input_ids = self._first_replica(input_ids, "input_ids")
        primary_probs = self._first_replica(probs, "probs")
        tokens = self.token_selector(primary_input_ids, primary_probs)
        # Tile (not interleave) to match the stacked-replica row layout.
        return tokens.repeat(self.n)