try:
    from extensions.telegram_bot.source.generators.abstract_generator import AbstractGenerator
except ImportError:
    from source.generators.abstract_generator import AbstractGenerator
import os, glob, sys
import torch
from typing import List

sys.path.append(os.path.join(os.path.split(__file__)[0], "exllama"))
from source.generators.exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from source.generators.exllama.tokenizer import ExLlamaTokenizer
from source.generators.exllama.generator import ExLlamaGenerator


class Generator(AbstractGenerator):
    """Text generator backed by an ExLlama model directory.

    The constructor loads the config, weights, tokenizer, KV cache and
    ExLlamaGenerator once; replies are produced with ``generate_answer``.
    """

    # Capability flags read by the bot wrapper.
    model_change_allowed = False  # if model changing allowed without stopping.
    preset_change_allowed = True  # if preset_file changing allowed.

    def __init__(self, model_path: str, n_ctx=4096, seed=0, n_gpu_layers=0):
        """Load the ExLlama model, tokenizer, cache and generator.

        Args:
            model_path: directory holding ``config.json``, ``tokenizer.model``
                and the ``.safetensors`` weights.
            n_ctx: context length; used as max sequence/input length, for the
                attention-size limit and for the cache size.
            seed: stored only; ``generate_answer`` reseeds randomly per call.
            n_gpu_layers: stored only; not used by this backend.

        Raises:
            FileNotFoundError: if no ``.safetensors`` weights file is found.
        """
        self.n_ctx = n_ctx
        self.seed = seed
        self.n_gpu_layers = n_gpu_layers
        self.model_directory = model_path

        # Locate files we need within that directory
        self.tokenizer_path = os.path.join(self.model_directory, "tokenizer.model")
        self.model_config_path = os.path.join(self.model_directory, "config.json")
        # Escape the directory so characters like "[" in it are not treated as glob syntax.
        escaped_dir = glob.escape(self.model_directory)
        self.st_pattern = os.path.join(escaped_dir, "model.safetensors")
        self.model_path = glob.glob(self.st_pattern)
        if not self.model_path:
            # Quantized checkpoints are often not named "model.safetensors";
            # fall back to whatever safetensors file(s) the directory holds.
            self.st_pattern = os.path.join(escaped_dir, "*.safetensors")
            self.model_path = sorted(glob.glob(self.st_pattern))
        if not self.model_path:
            raise FileNotFoundError(f"No .safetensors weights found in {self.model_directory!r}")

        # Create config, model, tokenizer and generator
        self.ex_config = ExLlamaConfig(self.model_config_path)  # create config from config.json
        # Upstream ExLlama reads the weights location from `config.model_path`;
        # `llm_path` is kept too in case the vendored copy uses that name.
        # NOTE(review): confirm which attribute the vendored exllama reads.
        self.ex_config.model_path = self.model_path
        self.ex_config.llm_path = self.model_path
        self.ex_config.max_seq_len = n_ctx
        self.ex_config.max_input_len = n_ctx
        self.ex_config.max_attention_size = n_ctx**2
        self.model = ExLlama(self.ex_config)  # create ExLlama instance and load the weights

        self.tokenizer = ExLlamaTokenizer(self.tokenizer_path)  # create tokenizer from tokenizer model file

        self.cache = ExLlamaCache(self.model, max_seq_len=n_ctx)  # create cache for inference
        self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)  # create generator

    def generate_answer(
        self, prompt, generation_params, eos_token, stopping_strings, default_answer: str, turn_template="", **kwargs
    ):
        """Generate a continuation of ``prompt`` and return only the new text.

        Args:
            prompt: full prompt text.
            generation_params: dict with keys "repetition_penalty", "temperature",
                "top_p", "top_k", "typical_p" and "max_new_tokens".
            eos_token: unused by this backend.
            stopping_strings: generation stops before the first of these strings.
            default_answer: returned unchanged if generation raises.
            turn_template: unused by this backend.

        Returns:
            The generated text without the prompt, or ``default_answer`` on error.
        """
        answer = default_answer

        try:
            # EOS is banned, so replies end only on a stopping string or the token budget.
            self.generator.disallow_tokens([self.tokenizer.eos_token_id])
            settings = self.generator.settings
            settings.token_repetition_penalty_max = generation_params["repetition_penalty"]
            settings.temperature = generation_params["temperature"]
            settings.top_p = generation_params["top_p"]
            settings.top_k = generation_params["top_k"]
            settings.typical = generation_params["typical_p"]
            # Fresh random seed per call so repeated prompts give different replies.
            random_seed = int.from_bytes(os.urandom(4), byteorder="big")
            torch.manual_seed(random_seed)
            torch.cuda.manual_seed(random_seed)
            prompt_text, text = self._generate(
                prompt, stopping_strings, max_new_tokens=generation_params["max_new_tokens"]
            )
            # Strip the *decoded* prompt: it can differ from the raw prompt
            # (tokenizer round-trip, truncation to max_seq_len).
            answer = text[len(prompt_text) :]
        except Exception as exception:
            print("generator_wrapper get answer error ", str(exception) + str(exception.args))
        return answer

    def generate_custom(self, prompt, stopping_strings: List, max_new_tokens=128):
        """Generate from ``prompt``; return the decoded prompt plus continuation.

        The text is cut just before the first stopping string found in the
        generated part, if any.
        """
        _, text = self._generate(prompt, stopping_strings, max_new_tokens)
        return text

    def _generate(self, prompt, stopping_strings, max_new_tokens):
        """Run token-by-token generation; return ``(prompt_text, text)``.

        ``prompt_text`` is the decoded prompt as fed to the model; ``text`` is
        the decoded full sequence, cut before the earliest stopping string.
        """
        self.generator.end_beam_search()

        ids, mask = self.tokenizer.encode(prompt, return_mask=True, max_seq_len=self.model.config.max_seq_len)
        self.generator.gen_begin(ids, mask=mask)
        # Decode from the generator's own sequence so later decodes share this prefix.
        prompt_text = self._decode_sequence()

        max_new_tokens = min(max_new_tokens, self.model.config.max_seq_len - ids.shape[1])
        # Empty stop strings would match instantly and wipe the answer; ignore them.
        stops = [stop for stop in (stopping_strings or []) if stop]

        eos = torch.zeros((ids.shape[0],), dtype=torch.bool)
        for _ in range(max_new_tokens):
            token = self.generator.gen_single_token(mask=mask)
            for j in range(token.shape[0]):
                if token[j, 0].item() == self.tokenizer.eos_token_id:
                    eos[j] = True
            if stops:
                text = self._decode_sequence()
                cut = self._find_stop(text, len(prompt_text), stops)
                if cut is not None:
                    return prompt_text, text[:cut]
            if eos.all():
                break

        return prompt_text, self._decode_sequence()

    def _decode_sequence(self):
        """Decode the generator's current token sequence (a str for batch size 1)."""
        sequence = self.generator.sequence
        return self.tokenizer.decode(sequence[0] if sequence.shape[0] == 1 else sequence)

    @staticmethod
    def _find_stop(text, prompt_len, stopping_strings):
        """Return the start index of the earliest stopping string that ends in
        the generated part of ``text`` (after ``prompt_len``), or None."""
        best = None
        for stop in stopping_strings:
            # Start early enough to catch a stop string straddling the prompt/answer
            # boundary, but late enough that a match includes a generated character.
            start = max(0, prompt_len - len(stop) + 1)
            idx = text.find(stop, start)
            if idx != -1 and (best is None or idx < best):
                best = idx
        return best

    def tokens_count(self, text: str):
        """Return the number of tokens in ``text`` (encoding capped at 20480 tokens)."""
        encoded = self.tokenizer.encode(text, max_seq_len=20480)
        return len(encoded[0])

    def get_model_list(self):
        """Return names of ``.bin`` files in ``../../models`` (relative to the CWD)."""
        return [name for name in os.listdir("../../models") if name.endswith(".bin")]

    def load_model(self, model_file: str):
        """Model switching is not supported by this backend; always returns None."""
        return None