import os
import re
from dataclasses import dataclass
from typing import Optional


@dataclass(unsafe_hash=True)
class Model:
    """Description of one chat model and how to reach and price it.

    The first four fields are positional: the name shown to users, the
    model identifier handed to the backend, an optional endpoint URL and the
    provider name. All prices are in US dollars.

    NOTE: ``unsafe_hash=True`` hashes on every field, so mutating an instance
    (e.g. toggling ``selected``) changes its hash; do not mutate instances
    that are stored in sets or used as dict keys.
    """

    model_display_name: str
    model_name: str
    api_url: Optional[str]  # may be None
    provider: str
    hourly_cost: Optional[float] = None  # dollars per hour
    cost_description: Optional[str] = None  # free-text override for `cost`
    supports_functions: bool = False
    size_billion_parameters: Optional[float] = None
    cost_per_million_tokens: Optional[float] = None  # flat price, input and output alike
    cost_per_million_input_tokens: Optional[float] = None
    cost_per_million_output_tokens: Optional[float] = None
    context_length: Optional[int] = None
    selected: bool = False

    def __post_init__(self):
        """Fill unset input/output prices from the flat per-token price.

        Uses ``is None`` instead of truthiness so an explicit price of 0
        (a free tier) is kept rather than being replaced.
        """
        if self.cost_per_million_input_tokens is None:
            self.cost_per_million_input_tokens = self.cost_per_million_tokens
        if self.cost_per_million_output_tokens is None:
            self.cost_per_million_output_tokens = self.cost_per_million_tokens

    @staticmethod
    def _format_price(value):
        """Format a dollar amount with two significant digits and no exponent.

        ``format(120, ".2g")`` gives ``"1.2e+02"``, which is unreadable as a
        price, so large values are re-rendered in fixed-point after rounding.
        """
        text = f"{value:.2g}"
        if "e+" in text:
            text = f"{float(text):.0f}"
        return text

    @property
    def cost(self):
        """Return a human-readable price string, or None when no price is known.

        Precedence: ``cost_description``, then ``hourly_cost``, then the flat
        ``cost_per_million_tokens``, then the separate input/output prices.
        """
        if self.cost_description:
            return self.cost_description
        if self.hourly_cost is not None:
            return f"${self._format_price(self.hourly_cost)} / hour"
        if self.cost_per_million_tokens is not None:
            return f"${self._format_price(self.cost_per_million_tokens)} / 1M tokens"
        if (
            self.cost_per_million_input_tokens is not None
            and self.cost_per_million_output_tokens is not None
        ):
            return (
                f"${self._format_price(self.cost_per_million_input_tokens)} / 1M input tokens, "
                f"${self._format_price(self.cost_per_million_output_tokens)} / 1M output tokens"
            )
        return None


def env(key, default=None):
    """Return the environment variable *key*, or *default* when it is unset."""
    return os.environ.get(key, default)

# Catalogue of every model this module knows about. Each entry passes the
# first four Model fields positionally: display name, backend model name
# (Together AI models use the "together_ai/" prefix), endpoint URL (None unless
# a dedicated endpoint is read from an environment variable) and provider.
# NOTE: "llama-2-70b-chat" / "LLaMA-2 Chat (70B)" and the two Mixtral-8x7B
# Together AI entries point at the same backend model under different
# display names.
MODELS = [
    # source: https://openai.com/pricing
    # converted costs from dollar/1K tokens to dollar/1M for readability and together_ai comparability
    Model(
        "gpt-3.5-turbo",
        "gpt-3.5-turbo",
        None,
        "OpenAI",
        supports_functions=True,
        cost_per_million_input_tokens=1,
        cost_per_million_output_tokens=2,
        # https://learn.microsoft.com/en-us/answers/questions/1356487/what-is-the-exact-maximum-input-tokens-of-azure-gp
        context_length=4096,
    ),
    Model(
        "gpt-4-turbo",
        "gpt-4-1106-preview",
        None,
        "OpenAI",
        supports_functions=True,
        cost_per_million_input_tokens=10,
        cost_per_million_output_tokens=30,
        # https://writesonic.com/blog/gpt-4-turbo-vs-gpt-4
        context_length=128_000,
    ),
    Model(
        "gpt-4",
        "gpt-4",
        None,
        "OpenAI",
        supports_functions=True,
        cost_per_million_input_tokens=30,
        cost_per_million_output_tokens=60,
        # base gpt-4 has an 8K context; the 32K window belongs to the separate gpt-4-32k model
        context_length=8192,
    ),
    # source: https://www.together.ai/pricing
    Model(
        "llama-2-70b-chat",
        "together_ai/togethercomputer/llama-2-70b-chat",
        None,
        "Together AI",
        cost_per_million_tokens=0.9,
        size_billion_parameters=70,
        # Llama 2 has a 4096-token context window, matching the "LLaMA-2 Chat (70B)" entry below
        context_length=4096,
    ),
    Model(
        "Mixtral-8x7B-Instruct-v0.1",
        "together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",
        None,
        "Together AI",
        # the 8 experts share attention layers: 46.7B parameters in total, not 8 * 7
        size_billion_parameters=46.7,
        cost_per_million_tokens=0.9,
    ),
    # taken from endpoint pages
    Model(
        "zephyr-7b-beta",
        "huggingface/HuggingFaceH4/zephyr-7b-beta",
        env("ZEPHYR_7B_BETA_URL"),
        "Hugging Face Inference Endpoint",
        hourly_cost=1.30,
        size_billion_parameters=7,
    ),
    Model(
        "Mistral-7B-Instruct-v0.2",
        "huggingface/mistralai/Mistral-7B-Instruct-v0.2",
        env("MISTRAL_7B_BETA_URL"),
        "Hugging Face Inference Endpoint",
        hourly_cost=1.30,
        size_billion_parameters=7,
    ),
    Model(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "huggingface/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        env("TINY_LLAMA_URL"),
        "Hugging Face Inference Endpoint",
        hourly_cost=0.60,
        size_billion_parameters=1.1,
    ),
    Model(
        "gemini-pro",
        "gemini-pro",
        None,
        "Google VertexAI",
        # https://ai.google.dev/pricing
        cost_description="$0.25 / 1M input characters, $0.5 / 1M output characters (60 queries per minute are free)",
        cost_per_million_input_tokens=0.25,
        cost_per_million_output_tokens=0.5,
    ),
    Model(
        "chat-bison (PaLM 2)",
        "chat-bison",
        None,
        "Google VertexAI",
        # https://cloud.google.com/vertex-ai/docs/generative-ai/pricing
        cost_per_million_input_tokens=0.25,
        cost_per_million_output_tokens=0.5,
        # https://ai.google.dev/models/palm
        context_length=8192,
    ),
    Model(
        "chat-bison-32k (PaLM 2 32K)",
        "chat-bison-32k",
        None,
        "Google VertexAI",
        # https://cloud.google.com/vertex-ai/docs/generative-ai/pricing
        cost_per_million_input_tokens=0.25,
        cost_per_million_output_tokens=0.5,
    ),
    # new models:
    Model(
        "01-ai Yi Chat (34B)",
        "together_ai/zero-one-ai/Yi-34B-Chat",
        None,
        "Together AI",
        context_length=4096,
        # selected=True,
    ),
    Model(
        "Chronos Hermes (13B)",
        "together_ai/Austism/chronos-hermes-13b",
        None,
        "Together AI",
        context_length=2048,
        # selected=True,
    ),
    Model(
        "Deepseek Coder Instruct (33B)",
        "together_ai/deepseek-ai/deepseek-coder-33b-instruct",
        None,
        "Together AI",
        context_length=16384,
    ),
    Model(
        "Platypus2 Instruct (70B)",
        "together_ai/garage-bAInd/Platypus2-70B-instruct",
        None,
        "Together AI",
        context_length=4096,
    ),
    Model(
        "MythoMax-L2 (13B)",
        "together_ai/Gryphe/MythoMax-L2-13b",
        None,
        "Together AI",
        context_length=4096,
    ),
    Model(
        "Vicuna v1.5 (13B)",
        "together_ai/lmsys/vicuna-13b-v1.5",
        None,
        "Together AI",
        context_length=4096,
    ),
    Model(
        "Vicuna v1.5 (7B)",
        "together_ai/lmsys/vicuna-7b-v1.5",
        None,
        "Together AI",
        context_length=4096,
        size_billion_parameters=7,
        # selected=True
    ),
    Model(
        "Code Llama Instruct (13B)",
        "together_ai/codellama/CodeLlama-13b-Instruct-hf",
        None,
        "Together AI",
        context_length=16384,
    ),
    Model(
        "Code Llama Instruct (34B)",
        "together_ai/codellama/CodeLlama-34b-Instruct-hf",
        None,
        "Together AI",
        context_length=16384,
    ),
    Model(
        "Code Llama Instruct (70B)",
        "together_ai/codellama/CodeLlama-70b-Instruct-hf",
        None,
        "Together AI",
        context_length=4096,
    ),
    Model(
        "Code Llama Instruct (7B)",
        "together_ai/codellama/CodeLlama-7b-Instruct-hf",
        None,
        "Together AI",
        size_billion_parameters=7,
        context_length=16384,
    ),
    Model(
        "LLaMA-2 Chat (13B)",
        "together_ai/togethercomputer/llama-2-13b-chat",
        None,
        "Together AI",
        size_billion_parameters=13,
        context_length=4096,
    ),
    Model(
        "LLaMA-2 Chat (70B)",
        "together_ai/togethercomputer/llama-2-70b-chat",
        None,
        "Together AI",
        size_billion_parameters=70,
        context_length=4096,
    ),
    Model(
        "LLaMA-2 Chat (7B)",
        "together_ai/togethercomputer/llama-2-7b-chat",
        None,
        "Together AI",
        size_billion_parameters=7,
        context_length=4096,
        # selected=True,
    ),
    Model(
        "Mistral (7B) Instruct",
        "together_ai/mistralai/Mistral-7B-Instruct-v0.1",
        None,
        "Together AI",
        size_billion_parameters=7,
        context_length=4096,
    ),
    Model(
        "Mistral (7B) Instruct v0.2 (Together AI)",
        "together_ai/mistralai/Mistral-7B-Instruct-v0.2",
        None,
        "Together AI",
        size_billion_parameters=7,
        context_length=32768,
        selected=True,
    ),
    Model(
        "Mixtral-8x7B Instruct (46.7B)",
        "together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",
        None,
        "Together AI",
        size_billion_parameters=46.7,
        context_length=32768,
    ),
    Model(
        "Nous Capybara v1.9 (7B)",
        "together_ai/NousResearch/Nous-Capybara-7B-V1p9",
        None,
        "Together AI",
        size_billion_parameters=7,
        context_length=8192,
    ),
    Model(
        "Nous Hermes 2 - Mixtral 8x7B-DPO (46.7B)",
        "together_ai/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
        None,
        "Together AI",
        size_billion_parameters=46.7,
        context_length=32768,
    ),
    Model(
        "Nous Hermes 2 - Mixtral 8x7B-SFT (46.7B)",
        "together_ai/NousResearch/Nous-Hermes-2-Mixtral-8x7B-SFT",
        None,
        "Together AI",
        size_billion_parameters=46.7,
        context_length=32768,
    ),
    Model(
        "Nous Hermes LLaMA-2 (7B)",
        "together_ai/NousResearch/Nous-Hermes-llama-2-7b",
        None,
        "Together AI",
        size_billion_parameters=7,
        context_length=4096,
    ),
    Model(
        "Nous Hermes Llama-2 (13B)",
        "together_ai/NousResearch/Nous-Hermes-Llama2-13b",
        None,
        "Together AI",
        size_billion_parameters=13,
        context_length=4096,
    ),
    Model(
        "Nous Hermes-2 Yi (34B)",
        "together_ai/NousResearch/Nous-Hermes-2-Yi-34B",
        None,
        "Together AI",
        size_billion_parameters=34,
        context_length=4096,
    ),
    Model(
        "OpenChat 3.5 (7B)",
        "together_ai/openchat/openchat-3.5-1210",
        None,
        "Together AI",
        size_billion_parameters=7,
        context_length=8192,
    ),
    Model(
        "OpenOrca Mistral (7B) 8K",
        "together_ai/Open-Orca/Mistral-7B-OpenOrca",
        None,
        "Together AI",
        size_billion_parameters=7,
        context_length=8192,
    ),
    Model(
        "Qwen-Chat (7B)",
        "together_ai/togethercomputer/Qwen-7B-Chat",
        None,
        "Together AI",
        size_billion_parameters=7,
        context_length=8192,
    ),
    Model(
        "Qwen 1.5 Chat (0.5B)",
        "together_ai/Qwen/Qwen1.5-0.5B-Chat",
        None,
        "Together AI",
        size_billion_parameters=0.5,
        context_length=32768,
    ),
    Model(
        "Qwen 1.5 Chat (1.8B)",
        "together_ai/Qwen/Qwen1.5-1.8B-Chat",
        None,
        "Together AI",
        context_length=32768,
        size_billion_parameters=1.8,
    ),
    Model(
        "Qwen 1.5 Chat (4B)",
        "together_ai/Qwen/Qwen1.5-4B-Chat",
        None,
        "Together AI",
        size_billion_parameters=4,
        context_length=32768,
    ),
    Model(
        "Qwen 1.5 Chat (7B)",
        "together_ai/Qwen/Qwen1.5-7B-Chat",
        None,
        "Together AI",
        context_length=32768,
        size_billion_parameters=7,
        # selected=True
    ),
    Model(
        "Qwen 1.5 Chat (14B)",
        "together_ai/Qwen/Qwen1.5-14B-Chat",
        None,
        "Together AI",
        size_billion_parameters=14,
        context_length=32768,
    ),
    Model(
        "Qwen 1.5 Chat (72B)",
        "together_ai/Qwen/Qwen1.5-72B-Chat",
        None,
        "Together AI",
        context_length=4096,
    ),
    Model(
        "Snorkel Mistral PairRM DPO (7B)",
        "together_ai/snorkelai/Snorkel-Mistral-PairRM-DPO",
        None,
        "Together AI",
        context_length=32768,
        # selected=True,
    ),
    Model(
        "Alpaca (7B)",
        "together_ai/togethercomputer/alpaca-7b",
        None,
        "Together AI",
        context_length=2048,
    ),
    Model(
        "OpenHermes-2-Mistral (7B)",
        # every other Together AI entry carries the "together_ai/" routing prefix
        "together_ai/teknium/OpenHermes-2-Mistral-7B",
        None,
        "Together AI",
        context_length=8192,
    ),
    Model(
        "OpenHermes-2.5-Mistral (7B)",
        "together_ai/teknium/OpenHermes-2p5-Mistral-7B",
        None,
        "Together AI",
        context_length=8192,
        # selected=True,
    ),
    Model(
        "Falcon Instruct (40B)",
        "together_ai/togethercomputer/falcon-40b-instruct",
        None,
        "Together AI",
        context_length=2048,
    ),
    Model(
        "Falcon Instruct (7B)",
        "together_ai/togethercomputer/falcon-7b-instruct",
        None,
        "Together AI",
        context_length=2048,
        # selected=True,
    ),
    Model(
        "LLaMA-2-7B-32K-Instruct (7B)",
        "together_ai/togethercomputer/Llama-2-7B-32K-Instruct",
        None,
        "Together AI",
        context_length=32768,
    ),
    Model(
        "RedPajama-INCITE Chat (3B)",
        "together_ai/togethercomputer/RedPajama-INCITE-Chat-3B-v1",
        None,
        "Together AI",
        size_billion_parameters=3,
        context_length=2048,
    ),
    Model(
        "RedPajama-INCITE Chat (7B)",
        "together_ai/togethercomputer/RedPajama-INCITE-7B-Chat",
        None,
        "Together AI",
        context_length=2048,
        size_billion_parameters=7,
        # selected=True,
    ),
    Model(
        "StripedHyena Nous (7B)",
        "together_ai/togethercomputer/StripedHyena-Nous-7B",
        None,
        "Together AI",
        context_length=32768,
        size_billion_parameters=7,
    ),
    Model(
        "ReMM SLERP L2 (13B)",
        "together_ai/Undi95/ReMM-SLERP-L2-13B",
        None,
        "Together AI",
        context_length=4096,
        size_billion_parameters=13,
    ),
    Model(
        "Toppy M (7B)",
        "together_ai/Undi95/Toppy-M-7B",
        None,
        "Together AI",
        context_length=4096,
        size_billion_parameters=7,
    ),
    Model(
        "WizardLM v1.2 (13B)",
        "together_ai/WizardLM/WizardLM-13B-V1.2",
        None,
        "Together AI",
        context_length=4096,
        size_billion_parameters=13,
        # selected=True,
    ),
    Model(
        "Upstage SOLAR Instruct v1 (11B)",
        "together_ai/upstage/SOLAR-10.7B-Instruct-v1.0",
        None,
        "Together AI",
        context_length=4096,
        size_billion_parameters=11,
        # selected=True,
    ),
]

# Display names of models that the post-processing loop always marks as
# selected (presumably because no measurements exist for them yet).
MISSING_MEASUREMENTS = ["gemini-pro", "chat-bison (PaLM 2)", "chat-bison-32k (PaLM 2 32K)"]

def together_ai_price(billions_of_parameters):
    """Return the Together AI price in dollars per 1M tokens for a model size.

    The size brackets are: up to 4B, 4.1-8B, 8.1-21B, 21.1-41B, and above 41B.
    Each bracket's upper bound is inclusive, and the next bracket starts
    immediately above it. Fractional sizes that fall between the published
    bounds (e.g. 4.05B or 8.05B) are therefore priced by the next tier
    instead of yielding None.

    :param billions_of_parameters: model size in billions of parameters.
    :return: price in dollars per 1M tokens.
    """
    tiers = ((4, 0.1), (8, 0.2), (21, 0.3), (41, 0.8))
    for upper_bound, price in tiers:
        if billions_of_parameters <= upper_bound:
            return price
    # Everything larger than 41B; no upper cap is applied (a 70B cap was left commented out).
    return 0.9

for model in MODELS:
    # Models named in MISSING_MEASUREMENTS are always marked as selected.
    if model.model_display_name in MISSING_MEASUREMENTS:
        model.selected = True

    # Only Together AI models get a size-derived price.
    if model.provider != "Together AI":
        continue

    # Look for a parameter count such as "(7B)" or "(46.7B)" in the display name.
    size_match = re.search(r"\((\d+(\.\d+)?)B\)", model.model_display_name)
    if size_match is None:
        continue

    model.size_billion_parameters = float(size_match.group(1))
    model.cost_per_million_tokens = together_ai_price(model.size_billion_parameters)
    # Re-run the dataclass hook so unset input/output prices fall back to the new flat price.
    model.__post_init__()