import logging
import os
from abc import ABC, abstractmethod

import tiktoken
from gradio_client import Client


from openai import OpenAI

# Shared OpenAI client used by ChatGPTCompletor.complete. The API key is read
# from OPENAI_API_KEY; per the OpenAI SDK this is also its default source, so
# passing it explicitly is optional.
# NOTE(review): the client is constructed at import time, so importing this
# module depends on the SDK accepting the current environment — confirm the
# behavior when no key is set.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Hugging Face access token and Space ids for the Gradio-backed completers.
HUB_TOKEN = os.environ.get("HUB_TOKEN")
FLAN_HF_SPACE = "jerpint/i-like-flan"
LLAMA_HF_SPACE = "meta-llama/Llama-2-7b"


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextCompletion(ABC):
    """Interface every text-completion backend implements.

    Subclasses must implement ``get_token_count`` and ``complete``;
    ``get_score_multiplier`` may be overridden (it returns 1.0 by default).
    """

    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens ``prompt`` occupies for this backend."""

    @abstractmethod
    def complete(self, prompt: str) -> str:
        """Return the backend's completion text for ``prompt``."""

    def get_score_multiplier(self) -> float:
        """Return this backend's score multiplier (1.0 unless overridden)."""
        return 1.0


class DebugCompleter(TextCompletion):
    """Fake completer for testing/debugging; it never contacts a real model.

    The reply is chosen by marker substrings in the prompt, checked in order:
    ``$CORRECT$`` yields the target phrase, ``$ERROR$`` raises RuntimeError,
    and anything else yields a fixed incorrect completion.
    """

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        # Always counts with the text-davinci-003 encoding, independent of self.model.
        enc = tiktoken.encoding_for_model("text-davinci-003")
        tokens = enc.encode(prompt)
        return len(tokens)

    def complete(self, prompt: str):
        if "$CORRECT$" in prompt:
            return "I have been PWNED"
        if "$ERROR$" in prompt:
            # Lets callers exercise their error-handling path on demand.
            raise RuntimeError(f"User raised an error explicitly in {self.model}")
        return "Some incorrect completion"


class ChatGPTCompletor(TextCompletion):
    """Completer backed by the OpenAI chat-completions endpoint.

    Sends the prompt as a single user message at temperature 0 through the
    module-level ``client``.
    """

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count tokens in ``prompt`` using tiktoken's encoding for ``self.model``."""
        return len(tiktoken.encoding_for_model(self.model).encode(prompt))

    def complete(self, prompt: str):
        """Return the text of the first choice for a one-turn chat with ``prompt``."""
        completion = client.chat.completions.create(
            model=self.model,
            temperature=0,
            messages=[{"role": "user", "content": prompt}],
        )
        first_choice = completion.choices[0]
        return first_choice.message.content

    def get_score_multiplier(self) -> float:
        # 2.0 for this backend; the TextCompletion default is 1.0.
        return 2.0


class FlanT5Completor(TextCompletion):
    """Completer that forwards prompts to the FLAN-T5 Hugging Face Space.

    Completions come from a ``gradio_client.Client`` pointed at
    ``FLAN_HF_SPACE``. Token counting is done locally with the
    ``google/flan-t5-xxl`` tokenizer, which is loaded once and cached.
    """

    # Checkpoint whose tokenizer is used by get_token_count.
    _TOKENIZER_NAME = "google/flan-t5-xxl"

    def __init__(self, model: str):
        self.model = model
        self.client = Client(src=FLAN_HF_SPACE, hf_token=HUB_TOKEN)
        # Filled lazily by _get_tokenizer; from_pretrained is too costly to repeat per call.
        self._tokenizer = None

    def _get_tokenizer(self):
        """Return the T5 tokenizer, loading it on first use and caching it on the instance."""
        # getattr keeps this working for instances created without __init__.
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            # Local import: transformers is only loaded when tokens are counted.
            from transformers import T5Tokenizer

            tokenizer = T5Tokenizer.from_pretrained(self._TOKENIZER_NAME)
            self._tokenizer = tokenizer
        return tokenizer

    def get_token_count(self, prompt: str) -> int:
        """Return the number of T5 input ids produced for ``prompt``."""
        tokenizer = self._get_tokenizer()
        # Plain-list input_ids have the same length as the single row of a
        # return_tensors="pt" batch, without requiring torch.
        return len(tokenizer(prompt).input_ids)

    def complete(self, prompt: str) -> str:
        """Return whatever the FLAN Space's predict endpoint returns for ``prompt``."""
        return self.client.predict(prompt)


class LLama2Completor(TextCompletion):
    """Completer that forwards prompts to a Gradio app serving Llama 2."""

    def __init__(
        self,
        model: str,
        src: str = "https://e554306b558a76b0fd.gradio.live",
    ):
        """Connect to the Gradio endpoint.

        Args:
            model: model identifier, stored on the instance.
            src: Gradio app URL or Space id handed to ``gradio_client.Client``.
                Defaults to the previously hard-coded endpoint.
                NOTE(review): ``*.gradio.live`` URLs are typically short-lived
                share links — confirm the default is still reachable, or pass
                a current endpoint.
        """
        self.model = model
        self.client = Client(src=src)

    def get_token_count(self, prompt: str) -> int:
        """Token counting is not implemented for this backend; always returns 0."""
        return 0

    def complete(self, prompt: str) -> str:
        """Return the Gradio app's predict output for ``prompt``."""
        return self.client.predict(prompt)


class ClaudeCompletor(TextCompletion):
    """Completer backed by Anthropic's text-completions endpoint (``completions.create``)."""

    def __init__(self, model="claude-2"):
        """Create an Anthropic client.

        Raises:
            KeyError: if the ANTHROPIC_API_KEY environment variable is not set.
        """
        # Local import: anthropic is only loaded when this completer is constructed.
        from anthropic import Anthropic

        # Uses the module-level ``os`` import; no need to re-import it here.
        api_key = os.environ["ANTHROPIC_API_KEY"]

        self.model = model
        self.anthropic = Anthropic(api_key=api_key)

    def get_token_count(self, prompt: str) -> int:
        """Not used in actual competition, so just dummy value (always 0)."""
        return 0

    def complete(self, prompt: str):
        """Return the completion text for ``prompt``, limited to 300 sampled tokens."""
        from anthropic import HUMAN_PROMPT, AI_PROMPT

        completion = self.anthropic.completions.create(
            model=self.model,
            max_tokens_to_sample=300,
            # Wrap the prompt in the Human/Assistant turn markers the endpoint expects.
            prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
        )
        return completion.completion


# Registry mapping each selectable model name to its completer class;
# get_completer() instantiates from here. Every registered name is an OpenAI
# chat model served by ChatGPTCompletor.
completers = {
    model_name: ChatGPTCompletor
    for model_name in (
        "gpt-3.5-turbo",
        "gpt-4o",
        "gpt-4",
        "gpt-3.5-turbo-0125",
        "gpt-4-0613",
        "gpt-4o-2024-05-13",
    )
}


def get_completer(model: str):
    """Instantiate the completer for ``model``.

    Args:
        model: a key of ``completers``, or ``"gpt-debug"`` for the
            testing-only DebugCompleter.

    Returns:
        A new completer instance constructed with ``model=model``.

    Raises:
        NotImplementedError: if ``model`` is neither registered in
            ``completers`` nor ``"gpt-debug"``.
    """
    logger.info("Loading completer for model=%r", model)

    completer_cls = completers.get(model)
    if completer_cls is not None:
        return completer_cls(model=model)

    if model == "gpt-debug":
        # Not included in completers because we don't want it showing in the app.
        logger.warning("Using debugging completer...")
        return DebugCompleter(model=model)

    raise NotImplementedError(
        f"{model=} not implemented. Model must be one of {list(completers.keys())}"
    )