# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Any, Callable, Dict, Optional, Sequence
from llama_index.bridge.pydantic import Field, PrivateAttr
from llama_index.callbacks import CallbackManager
from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.llms.base import (
    ChatMessage,
    ChatResponse,
    CompletionResponse,
    ChatResponseGen,
    CompletionResponseGen,
    LLMMetadata,
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.custom import CustomLLM
from llama_index.llms.generic_utils import stream_completion_response_to_chat_response
from llama_index.llms.generic_utils import completion_response_to_chat_response
from llama_index.llms.generic_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)
from utils import (DEFAULT_HF_MODEL_DIRS, DEFAULT_PROMPT_TEMPLATES,
                   load_tokenizer, read_model_name, throttle_generator)
import gc
import torch
import tensorrt_llm
import uuid
import time
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
from tensorrt_llm.logger import logger
# Special-token id 2 serves as both end-of-sequence and padding id here.
# NOTE(review): presumably the `</s>` id of Llama-family tokenizers — confirm
# against the tokenizer actually loaded.
EOS_TOKEN = PAD_TOKEN = 2

class TrtLlmAPI(CustomLLM):
    """LlamaIndex ``CustomLLM`` backed by a TensorRT-LLM engine.

    The engine directory (``model_path``) is loaded with
    ``ModelRunner.from_dir``. Prompts are tokenized with the tokenizer
    returned by ``utils.load_tokenizer``.

    NOTE(review): the runner is always called with ``top_k=1`` and
    ``temperature=1.0`` (greedy decoding). The ``temperature`` field is only
    recorded in ``generate_kwargs`` and is not forwarded to the engine.
    """

    model_path: Optional[str] = Field(
        description="The path to the trt engine."
    )
    temperature: float = Field(description="The temperature to use for sampling.")
    max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
    context_window: int = Field(
        description="The maximum number of context tokens for the model."
    )
    messages_to_prompt: Callable = Field(
        description="The function to convert messages to a prompt.", exclude=True
    )
    completion_to_prompt: Callable = Field(
        description="The function to convert a completion to a prompt.", exclude=True
    )
    generate_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for generation."
    )
    model_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for model initialization."
    )
    verbose: bool = Field(description="Whether to print verbose output.")

    _model: Any = PrivateAttr()  # ModelRunner instance, or None after unload_model()
    _model_config: Any = PrivateAttr()  # declared but not read in this class
    _tokenizer: Any = PrivateAttr()
    _pad_id:Any = PrivateAttr()
    _end_id: Any = PrivateAttr()
    _new_max_token: Any = PrivateAttr()  # max_new_tokens passed to the runner
    _max_new_tokens = PrivateAttr()  # declared but not read in this class
    _sampling_config = PrivateAttr()  # declared but not read in this class
    _verbose = PrivateAttr()  # declared but not read in this class

    def __init__(
            self,
            model_path: Optional[str] = None,
            engine_name: Optional[str] = None,
            tokenizer_dir: Optional[str] = None,
            temperature: float = 0.1,
            max_new_tokens: int = DEFAULT_NUM_OUTPUTS,
            context_window: int = DEFAULT_CONTEXT_WINDOW,
            messages_to_prompt: Optional[Callable] = None,
            completion_to_prompt: Optional[Callable] = None,
            callback_manager: Optional[CallbackManager] = None,
            generate_kwargs: Optional[Dict[str, Any]] = None,
            model_kwargs: Optional[Dict[str, Any]] = None,
            verbose: bool = False
    ) -> None:
        """Load the tokenizer and the TensorRT-LLM engine, then initialise fields.

        Args:
            model_path: Engine directory. It is passed to
                ``ModelRunner.from_dir`` and ``read_model_name``.
            engine_name: Accepted for backward compatibility; not used.
            tokenizer_dir: Directory handed to ``load_tokenizer``.
            temperature: Stored on the model and in ``generate_kwargs``.
            max_new_tokens: Maximum number of tokens generated per request.
            context_window: Reported through ``metadata`` and stored in
                ``model_kwargs`` as ``n_ctx``.
            messages_to_prompt: Chat formatter. Defaults to llama_index's
                generic ``messages_to_prompt``.
            completion_to_prompt: Completion formatter. Defaults to identity.
            callback_manager: Forwarded to ``CustomLLM``.
            generate_kwargs: Extra generation kwargs. The caller's dict is
                copied, never mutated.
            model_kwargs: Extra model kwargs. The caller's dict is copied,
                never mutated.
            verbose: Stored on the model and in ``model_kwargs``.
        """
        # Copy so the caller's dict is not mutated by the update below.
        model_kwargs = dict(model_kwargs or {})
        model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
        runtime_rank = tensorrt_llm.mpi_rank()
        model_name = read_model_name(model_path)

        self._tokenizer, self._pad_id, self._end_id = load_tokenizer(
            tokenizer_dir=tokenizer_dir,
            model_name=model_name,
        )
        self._model = ModelRunner.from_dir(
            engine_dir=model_path,
            rank=runtime_rank,
            debug_mode=True,
            lora_ckpt_source='hf',
        )
        messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
        completion_to_prompt = completion_to_prompt or (lambda x: x)

        # Copy so the caller's dict is not mutated by the update below.
        generate_kwargs = dict(generate_kwargs or {})
        generate_kwargs.update(
            {"temperature": temperature, "max_tokens": max_new_tokens}
        )
        self._new_max_token = max_new_tokens

        super().__init__(
            model_path=model_path,
            temperature=temperature,
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            callback_manager=callback_manager,
            generate_kwargs=generate_kwargs,
            model_kwargs=model_kwargs,
            verbose=verbose,
        )

    @classmethod
    def class_name(cls) -> str:
        """Get class name."""
        return "TrtLlmAPI"

    @property
    def metadata(self) -> LLMMetadata:
        """LLM metadata: context window, output budget and model path."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name=self.model_path,
        )

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        """Format ``messages`` into a prompt and return the chat response from ``complete``."""
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        """Format ``messages`` into a prompt and stream chat responses from ``stream_complete``."""
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)

    def _encode_prompt(self, prompt: str, kwargs: Dict[str, Any]):
        """Format (unless ``formatted``) and tokenize one prompt.

        Pops ``formatted`` from ``kwargs``. Returns ``(batch_input_ids,
        input_lengths)``: a one-element list of int32 tensors and their lengths.
        """
        if not kwargs.pop("formatted", False):
            prompt = self.completion_to_prompt(prompt)
        batch_input_ids = self.parse_input(self._tokenizer,
                                           [prompt],
                                           pad_id=self._pad_id)
        input_lengths = [x.size(1) for x in batch_input_ids]
        return batch_input_ids, input_lengths

    def _run_generate(self, batch_input_ids, streaming: bool):
        """Call ``ModelRunner.generate`` with this wrapper's fixed decoding settings.

        Returns what the runner returns for ``return_dict=True``. The callers
        index it as a dict (non-streaming) or iterate it as dicts (streaming).
        """
        with torch.no_grad():
            outputs = self._model.generate(
                batch_input_ids,
                max_new_tokens=self._new_max_token,
                end_id=self._end_id,
                pad_id=self._pad_id,
                # top_k=1 keeps only the most likely token, i.e. greedy decoding.
                temperature=1.0,
                top_k=1,
                top_p=0,
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=1.0,
                stop_words_list=None,
                bad_words_list=None,
                lora_uids=None,
                prompt_table_path=None,
                prompt_tasks=None,
                streaming=streaming,
                output_sequence_lengths=True,
                return_dict=True)
            torch.cuda.synchronize()
        return outputs

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        """Generate the full completion for ``prompt`` in one non-streaming call.

        Pass ``formatted=True`` to skip ``completion_to_prompt``.
        """
        # Kept for backward compatibility. generate_kwargs is never passed to
        # the runner in this class.
        self.generate_kwargs.update({"stream": False})
        batch_input_ids, input_lengths = self._encode_prompt(prompt, kwargs)
        outputs = self._run_generate(batch_input_ids, streaming=False)

        output_txt, _ = self.print_output(self._tokenizer,
                                          outputs['output_ids'],
                                          input_lengths,
                                          outputs['sequence_lengths'])
        # Release cached GPU memory and Python garbage after each inference.
        torch.cuda.empty_cache()
        gc.collect()
        return CompletionResponse(text=output_txt,
                                  raw=self.generate_completion_dict(output_txt))

    def parse_input(self,
                    tokenizer,
                    input_text=None,
                    prompt_template=None,
                    input_file=None,
                    add_special_tokens=True,
                    max_input_length=4096,
                    pad_id=None,
                    num_prepend_vtokens=None):
        """Tokenize each string in ``input_text`` into a ``(1, seq_len)`` int32 tensor.

        Args:
            tokenizer: Object with an ``encode(text, add_special_tokens=...,
                truncation=..., max_length=...)`` method.
            input_text: Iterable of prompt strings. ``None`` is treated as empty.
            prompt_template: Optional ``str.format`` template with an
                ``{input_text}`` placeholder applied to each prompt.
            input_file: Accepted for compatibility; not used.
            add_special_tokens: Forwarded to ``tokenizer.encode``.
            max_input_length: Truncation length forwarded to ``tokenizer.encode``.
            pad_id: Accepted for compatibility; not used. Each prompt becomes
                its own tensor, so no padding happens here.
            num_prepend_vtokens: Optional per-prompt counts of virtual token ids
                (starting at the base vocabulary size) to prepend.

        Returns:
            A list with one ``(1, seq_len)`` int32 tensor per prompt.

        Raises:
            ValueError: if ``num_prepend_vtokens`` and ``input_text`` differ in length.
        """
        batch_input_ids = []
        for curr_text in input_text or []:
            if prompt_template is not None:
                curr_text = prompt_template.format(input_text=curr_text)
            batch_input_ids.append(
                tokenizer.encode(curr_text,
                                 add_special_tokens=add_special_tokens,
                                 truncation=True,
                                 max_length=max_input_length))

        if num_prepend_vtokens:
            if len(num_prepend_vtokens) != len(batch_input_ids):
                raise ValueError(
                    "num_prepend_vtokens must have one entry per input text")
            base_vocab_size = tokenizer.vocab_size - len(
                tokenizer.special_tokens_map.get('additional_special_tokens', []))
            batch_input_ids = [
                list(range(base_vocab_size, base_vocab_size + length)) + ids
                for ids, length in zip(batch_input_ids, num_prepend_vtokens)
            ]

        return [
            torch.tensor(ids, dtype=torch.int32).unsqueeze(0)
            for ids in batch_input_ids
        ]

    def remove_extra_eos_ids(self, outputs, eos_id=EOS_TOKEN):
        """Collapse any run of trailing ``eos_id`` tokens into exactly one.

        Returns a new list. ``outputs`` is not modified. A single ``eos_id``
        is always appended, even if ``outputs`` had no trailing EOS.
        """
        end = len(outputs)
        while end > 0 and outputs[end - 1] == eos_id:
            end -= 1
        return list(outputs[:end]) + [eos_id]

    def print_output(self,
                     tokenizer,
                     output_ids,
                     input_lengths,
                     sequence_lengths,
                     output_csv=None,
                     output_npy=None,
                     context_logits=None,
                     generation_logits=None,
                     output_logits_npy=None):
        """Decode the generated (non-prompt) tokens from a runner output.

        ``output_ids`` is unpacked as ``(batch, beams, seq)``. For each batch
        item and beam, the tokens from ``input_lengths[b]`` up to
        ``sequence_lengths[b][beam]`` are decoded. The logits and file
        parameters are accepted for compatibility and not used.

        Returns:
            ``(text, ids)``: the decoded text of the LAST batch item and beam
            (``""`` if ``output_csv`` or ``output_npy`` is set), and
            ``output_ids`` reshaped to ``(batch * beams, seq)``.
        """
        output_text = ""
        batch_size, num_beams, _ = output_ids.size()
        if output_csv is None and output_npy is None:
            for batch_idx in range(batch_size):
                output_begin = input_lengths[batch_idx]
                for beam in range(num_beams):
                    output_end = sequence_lengths[batch_idx][beam]
                    outputs = output_ids[batch_idx][beam][
                              output_begin:output_end].tolist()
                    # Later items overwrite earlier ones; only the last is returned.
                    output_text = tokenizer.decode(outputs)

        output_ids = output_ids.reshape((-1, output_ids.size(2)))
        return output_text, output_ids

    def get_output(self, output_ids, input_lengths, max_output_len, tokenizer):
        """Decode up to ``max_output_len`` generated tokens of beam 0 per sequence.

        ``input_lengths`` may be a list or a 1-D tensor. Trailing end tokens
        are collapsed into one before decoding.

        Returns:
            ``(text, ids)`` for the LAST sequence, or ``("", None)`` if
            ``input_lengths`` is empty.
        """
        eos_id = EOS_TOKEN if self._end_id is None else self._end_id
        output_text = ""
        outputs = None
        for b in range(len(input_lengths)):
            output_begin = input_lengths[b]
            output_end = output_begin + max_output_len
            outputs = output_ids[b][0][output_begin:output_end].tolist()
            outputs = self.remove_extra_eos_ids(outputs, eos_id=eos_id)
            output_text = tokenizer.decode(outputs)

        return output_text, outputs

    def generate_completion_dict(self, text_str):
        """Build an OpenAI-style ``text_completion`` dict for ``text_str``.

        The ``model`` field is ``model_path``. Token usage counts are not
        tracked and are reported as ``None``.

        Returns:
            dict: A dictionary containing completion details.
        """
        completion_id: str = f"cmpl-{str(uuid.uuid4())}"
        created: int = int(time.time())
        # The runner object itself is not a model name (nor serializable).
        model_name: Optional[str] = self.model_path
        return {
            "id": completion_id,
            "object": "text_completion",
            "created": created,
            "model": model_name,
            "choices": [
                {
                    "text": text_str,
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": 'stop'
                }
            ],
            "usage": {
                "prompt_tokens": None,
                "completion_tokens": None,
                "total_tokens": None
            }
        }

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        """Stream a completion as ``CompletionResponse`` objects.

        Each response carries the full text so far (``text``) and the newly
        added suffix (``delta``). Pass ``formatted=True`` to skip
        ``completion_to_prompt``.
        """
        batch_input_ids, input_lengths = self._encode_prompt(prompt, kwargs)
        outputs = self._run_generate(batch_input_ids, streaming=True)

        def gen() -> CompletionResponseGen:
            previous_text = ""  # text already emitted, used to compute the delta
            # throttle_generator comes from utils. NOTE(review): presumably it
            # yields every 5th streamed step; confirm its exact semantics.
            for curr_outputs in throttle_generator(outputs, 5):
                output_txt, _ = self.print_output(self._tokenizer,
                                                  curr_outputs['output_ids'],
                                                  input_lengths,
                                                  curr_outputs['sequence_lengths'])
                # Drop a trailing "</s>" marker left by decoding.
                if output_txt.endswith("</s>"):
                    output_txt = output_txt[:-4]
                new_text = output_txt[len(previous_text):]
                yield CompletionResponse(delta=new_text, text=output_txt,
                                         raw=self.generate_completion_dict(output_txt))
                previous_text = output_txt

        return gen()

    def unload_model(self):
        """Drop the TensorRT-LLM runner and free cached GPU memory.

        Safe to call more than once. ``_model`` is reset to ``None`` rather
        than deleted, so later ``is not None`` checks still work.
        """
        self._model = None
        torch.cuda.empty_cache()
        gc.collect()