import os
import re
import sys
import time
import random
import yaml
import subprocess
from io import StringIO

import runpod
import shutil
import requests
import gradio as gr
import pandas as pd
from jinja2 import Template
from huggingface_hub import ModelCard, ModelCardData, HfApi, repo_info
from huggingface_hub.utils import RepositoryNotFoundError

# Credentials come from the environment (None when a variable is unset).
HF_TOKEN = os.getenv("HF_TOKEN")
runpod.api_key = os.getenv("RUNPOD_TOKEN")

# Parameters
USERNAME = 'automerger'  # Hub namespace that receives the merged models
N_ROWS = 15  # number of top leaderboard rows the two parents are sampled from
WAIT_TIME = 3 * 60 * 60  # seconds to sleep between merge iterations (3 hours)


# Logger from https://github.com/gradio-app/gradio/issues/2362
class Logger:
    """Tee ``sys.stdout`` into a log file so the Gradio UI can display it.

    Based on https://github.com/gradio-app/gradio/issues/2362. Every
    ``write``/``flush`` goes both to the stdout that was active when the
    Logger was created and to *filename*.

    Args:
        filename: Path of the log file; it is truncated on construction.
        encoding: Text encoding used for the log file (UTF-8 by default so
            any printed character can be logged).
    """

    def __init__(self, filename, encoding="utf-8"):
        self.terminal = sys.stdout
        # Truncate once, then reopen in append mode. Other writers (the
        # mergekit subprocess in merge_models) also append to this file; with
        # append mode every write lands at the current end of file instead of
        # at this handle's own offset, so their output is not overwritten.
        open(filename, "w", encoding=encoding).close()
        self.log = open(filename, "a", encoding=encoding)

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        # Always report a non-TTY stream (presumably so libraries do not emit
        # terminal control sequences into the log).
        return False

    def __getattr__(self, name):
        # Delegate everything else (``encoding``, ``fileno``, ...) to the
        # wrapped stream so code expecting a real file object keeps working.
        # Read __dict__ directly to avoid infinite recursion before __init__.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)


def read_logs():
    """Return the current contents of ``output.log`` for the Gradio log box.

    stdout is flushed first so buffered ``print`` output is included. The file
    is decoded as UTF-8 with undecodable bytes replaced: mergekit's raw
    subprocess output is appended to the same file, and one bad byte must not
    break the periodic refresh. Returns an empty string if the file does not
    exist yet.
    """
    sys.stdout.flush()
    try:
        with open("output.log", "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    except FileNotFoundError:
        return ""

        
def create_dataset() -> bool:
    """
    Use Scrape Open LLM Leaderboard to create a CSV dataset.

    Runs ``scrape-open-llm-leaderboard/main.py -csv`` and prints its output.

    Returns:
        True if the scraper exited successfully; False if it exited with a
        non-zero status or could not be launched at all.
    """
    command = ["python3", "scrape-open-llm-leaderboard/main.py", "-csv"]

    try:
        result = subprocess.run(command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"scrape-open-llm-leaderboard: {e.stderr}")
        return False
    except OSError as e:
        # e.g. the interpreter is missing: report it instead of letting the
        # exception escape and stop the caller's merge loop.
        print(f"scrape-open-llm-leaderboard: {e}")
        return False

    print(f"scrape-open-llm-leaderboard: {result.stdout}")
    return True


def merge_models() -> bool:
    """
    Use mergekit to create a merge.

    Runs ``mergekit-yaml config.yaml /data/merge`` (copying the tokenizer and
    using ``/data`` as the transformers cache). mergekit's stdout and stderr
    are appended directly to ``output.log`` so they appear in the log view.

    Returns:
        True if mergekit exited successfully, False if it failed or could not
        be started.
    """
    command = ["mergekit-yaml", "config.yaml", "/data/merge", "--copy-tokenizer", "--transformers_cache", "/data"]

    # Flush our own buffered prints first so the log stays in order.
    sys.stdout.flush()
    with open("output.log", "a") as log_file:
        try:
            subprocess.run(command, check=True, stdout=log_file,
                           stderr=log_file, text=True)
        except subprocess.CalledProcessError as e:
            # The output already went to output.log; e.stdout/e.stderr are None.
            print(f"Error: mergekit {command} exited with status {e.returncode}")
            return False
        except OSError as e:
            print(f"Error: mergekit {command}: {e}")
            return False

    print("mergekit: done")
    return True


def make_df(file_path: str, n_rows: int) -> pd.DataFrame:
    """
    Create a filtered dataset from the Open LLM Leaderboard.

    Reads the leaderboard CSV at *file_path* (merge_loop passes
    ``open-llm-leaderboard.csv``). It keeps 8B ``LlamaForCausalLM`` rows that
    pass the filters below and drops bookkeeping columns. It then removes
    duplicate model names and returns up to *n_rows* rows with the highest
    MMLU.

    Args:
        file_path: Path of the leaderboard CSV to read.
        n_rows: Maximum number of rows to return.

    Returns:
        The filtered DataFrame sorted by descending ``MMLU``.
    """
    columns = ["Available on the hub", "Model sha", "T", "Type", "Precision",
               "Architecture", "Weight type", "Hub ❀️", "Flagged", "MoE"]
    ds = pd.read_csv(file_path, encoding='utf-8')
    model_names = ds["Model"].str.lower()
    mask = (
        (ds["#Params (B)"] == 8)
        & (ds["Architecture"] == "LlamaForCausalLM")
        & (ds["Available on the hub"] == True)  # noqa: E712 (element-wise)
        # NOTE(review): keeping Flagged == True and MoE == True looks inverted
        # for plain 8B Llama models; presumably the scraper encodes these
        # columns differently -- confirm against its CSV before changing.
        & (ds["Flagged"] == True)  # noqa: E712
        & ~model_names.str.contains("yi", regex=False)
        & ~model_names.str.contains("9b", regex=False)
        & ~model_names.str.contains("8xqmff94/slm", regex=False)
        & (ds["MoE"] == True)  # noqa: E712
        & (ds["Weight type"] == "Original")
    )
    df = (
        ds[mask]
        .drop(columns=columns)
        .drop_duplicates(subset=["Model"])
        .sort_values(by="MMLU", ascending=False)
        .iloc[:n_rows]
    )
    return df


def repo_exists(repo_id: str) -> bool:
    """Return whether *repo_id* can be found on the Hugging Face Hub.

    Only ``RepositoryNotFoundError`` counts as "does not exist"; any other
    error raised by ``repo_info`` propagates to the caller.
    """
    try:
        repo_info(repo_id)
    except RepositoryNotFoundError:
        return False
    return True


def get_name(models: list[pd.Series], username: str, version=0, exists_fn=None) -> str:
    """
    Build a repo name for the merge that is not yet taken under *username*.

    The base name concatenates the capitalised first dash-separated token of
    each parent's name (e.g. ``org/alpha-7b`` + ``x/beta-merge`` ->
    ``AlphaBeta``). Candidates are ``<base>-8B`` for version 0 and
    ``<base>-v<version>-8B`` afterwards. The version is incremented until
    ``exists_fn(f"{username}/{candidate}")`` is False.

    Args:
        models: The two parent rows; only ``["Model"]`` is read.
        username: Hub namespace the model will be uploaded to.
        version: First version number to try.
        exists_fn: Callable ``repo_id -> bool``; defaults to ``repo_exists``.

    Returns:
        The first free model name.
    """
    if exists_fn is None:
        exists_fn = repo_exists

    base = "".join(
        model["Model"].split("/")[-1].split("-")[0].capitalize()
        for model in models[:2]
    )
    # Iterate rather than recurse: the previous recursive call discarded its
    # result and returned an already-taken name.
    while True:
        model_name = f"{base}-v{version}-8B" if version > 0 else f"{base}-8B"
        if not exists_fn(f"{username}/{model_name}"):
            return model_name
        version += 1
    

def get_license(models: list[pd.Series]) -> str:
    """
    Choose the license for a merge of ``models[0]`` and ``models[1]``.

    The parents' ``"Hub License"`` values are compared case-insensitively:

    * ``cc-by-nc-4.0`` if either parent is cc-by-nc-4.0;
    * otherwise ``apache-2.0`` if either parent is apache-2.0;
    * otherwise ``mit`` if both parents are mit;
    * otherwise the ``cc-by-nc-4.0`` fallback.

    Identifiers are returned in lowercase, like the other two used here, so
    they can go straight into the model card's ``license:`` metadata.
    """
    def normalize(value) -> str:
        # Non-string entries (e.g. NaN for a missing license) match nothing.
        return value.strip().lower() if isinstance(value, str) else ""

    licenses = {normalize(models[0]["Hub License"]), normalize(models[1]["Hub License"])}

    if "cc-by-nc-4.0" in licenses:
        return "cc-by-nc-4.0"
    if "apache-2.0" in licenses:
        return "apache-2.0"
    if licenses == {"mit"}:
        return "mit"
    return "cc-by-nc-4.0"


def create_config(models: list[pd.Series]) -> str:
    """
    Pick a random mergekit config for the two parents and write it to ``config.yaml``.

    One of three templates is chosen with ``random.choices``:

    * slerp (weight 0.3): layers 0-32 of both models, ``models[0]`` as base;
    * dare_ties (weight 0.6): ``models[0]`` as base, ``models[1]`` with
      density 0.53 / weight 0.6 and ``int8_mask`` enabled;
    * model_stock (weight 0.1): both models on top of meta-llama/Meta-Llama-3-8B.

    Args:
        models: The two parent rows; only ``["Model"]`` is read.

    Returns:
        The YAML text that was written to ``config.yaml``.
    """
    slerp_config = f"""
slices:
  - sources:
      - model: {models[0]["Model"]}
        layer_range: [0, 32]
      - model: {models[1]["Model"]}
        layer_range: [0, 32]
merge_method: slerp
base_model: {models[0]["Model"]}
parameters:
  t:
    - filter: self_attn
      value: [0, 0.5, 0.3, 0.7, 1]
    - filter: mlp
      value: [1, 0.5, 0.7, 0.3, 0]
    - value: 0.5
dtype: bfloat16
random_seed: 0
    """
    # int8_mask must be indented under `parameters:`; unindented, YAML parses
    # `parameters` as null and `int8_mask` as an unrelated top-level key.
    dare_config = f"""
models:
- model: {models[0]["Model"]}
  # No parameters necessary for base model
- model: {models[1]["Model"]}
  parameters:
    density: 0.53
    weight: 0.6
merge_method: dare_ties
base_model: {models[0]["Model"]}
parameters:
  int8_mask: true
dtype: bfloat16
random_seed: 0
"""
    stock_config = f"""
models:
  - model: meta-llama/Meta-Llama-3-8B
  - model: {models[0]["Model"]}
  - model: {models[1]["Model"]}
merge_method: model_stock
base_model: meta-llama/Meta-Llama-3-8B
dtype: bfloat16    
"""
    yaml_config = random.choices([slerp_config, dare_config, stock_config], weights=[0.3, 0.6, 0.1], k=1)[0]

    with open('config.yaml', 'w', encoding="utf-8") as f:
        f.write(yaml_config)

    return yaml_config


def _models_from_config(data: dict) -> list[str]:
    """Return the parent model ids to credit in the model card for a parsed mergekit config."""
    if "models" in data:
        entries = data["models"]
        # dare_ties-style configs: the base model carries no "parameters";
        # credit the entries that do.
        credited = [entry["model"] for entry in entries if "parameters" in entry]
        if not credited:
            # model_stock-style configs give no entry any parameters; credit
            # every listed model except the shared base_model.
            base_model = data.get("base_model")
            credited = [entry["model"] for entry in entries if entry["model"] != base_model]
        return credited
    if "parameters" in data:
        # slerp-style configs: every source of the first slice.
        return [source["model"] for source in data["slices"][0]["sources"]]
    if "slices" in data:
        # Otherwise: the first source of every slice.
        return [layer_slice["sources"][0]["model"] for layer_slice in data["slices"]]
    raise ValueError("No models or slices found in yaml config")


def create_model_card(yaml_config: str, model_name: str, username: str, license: str) -> None:
    """
    Render the model card for a merge and save it to ``/data/merge/README.md``.

    Args:
        yaml_config: The mergekit YAML used for the merge; it is parsed to find
            the credited parent models and embedded verbatim in the card.
        model_name: Name of the merged model repository.
        username: Hub namespace the model is uploaded to.
        license: License identifier written to the card's metadata.

    Raises:
        ValueError: if the config contains neither ``models`` nor ``slices``.
    """
    template_text = """
---
license: {{ license }}
base_model:
{%- for model in models %}
  - {{ model }}
{%- endfor %}
tags:
- merge
- mergekit
- lazymergekit
- automerger
---

# {{ model_name }}

{{ model_name }} is an automated merge created by [Maxime Labonne](https://huggingface.co/mlabonne) using the following configuration.

{%- for model in models %}
* [{{ model }}](https://huggingface.co/{{ model }})
{%- endfor %}

## 🧩 Configuration

```yaml
{{- yaml_config -}}

```

## 💻 Usage

```python
!pip install -qU transformers accelerate

from transformers import AutoTokenizer
import transformers
import torch

model = "{{ username }}/{{ model_name }}"
messages = [{"role": "user", "content": "What is a large language model?"}]

tokenizer = AutoTokenizer.from_pretrained(model)
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

outputs = pipeline(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])
```
"""

    # Create a Jinja template object
    jinja_template = Template(template_text.strip())

    # Get the list of credited parent models from the config
    data = yaml.safe_load(yaml_config)
    models = _models_from_config(data)

    # Fill the template
    content = jinja_template.render(
        model_name=model_name,
        models=models,
        yaml_config=yaml_config,
        username=username,
        license=license
    )

    # Save the model card
    card = ModelCard(content)
    card.save('/data/merge/README.md')


def upload_model(api: HfApi, username: str, model_name: str) -> None:
    """
    Push the contents of ``/data/merge`` to ``{username}/{model_name}`` on the Hub.

    The model repository is created first; ``exist_ok=True`` makes that a
    no-op when it already exists.
    """
    repo_id = f"{username}/{model_name}"
    api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    api.upload_folder(repo_id=repo_id, folder_path="/data/merge")


def create_pod(model_name: str, username: str, n=10, wait_seconds=10):
    """
    Create a RunPod instance to run the evaluation.

    Requests one RTX 3090 community pod from template ``au6nz6emhk``. The env
    vars configure the llm-autoeval repository with ``BENCHMARK=nous`` for
    ``{username}/{model_name}``. Creation is attempted up to *n* times, with
    *wait_seconds* of sleep between failures.

    Returns:
        Whatever ``runpod.create_pod`` returns for the created pod.

    Raises:
        KeyError: if ``GITHUB_TOKEN`` is not set. This is raised before any
            attempt, since retrying cannot fix a missing variable.
        Exception: the last error from ``runpod.create_pod`` once all *n*
            attempts have failed.
    """
    # The env payload is identical for every attempt; build it once.
    env = {
        "BENCHMARK": "nous",
        "MODEL_ID": f"{username}/{model_name}",
        "REPO": "https://github.com/mlabonne/llm-autoeval.git",
        "TRUST_REMOTE_CODE": False,
        "PRIVATE_GIST": False,
        "YALL_GIST_ID": "56ebbd012d942a6b749db5243de5740f",
        "DEBUG": False,
        "GITHUB_API_TOKEN": os.environ["GITHUB_TOKEN"],
    }

    for attempt in range(n):
        try:
            pod = runpod.create_pod(
                name=f"Automerge {model_name} on Nous",
                image_name="runpod/pytorch:2.0.1-py3.10-cuda11.8.0-devel-ubuntu22.04",
                gpu_type_id="NVIDIA GeForce RTX 3090",
                cloud_type="COMMUNITY",
                gpu_count=1,
                volume_in_gb=0,
                container_disk_in_gb=50,
                template_id="au6nz6emhk",
                env=env,
            )
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
            if attempt < n - 1:
                print(f"Waiting {wait_seconds} seconds before retrying...")
                time.sleep(wait_seconds)
            else:
                print("All attempts failed. Giving up.")
                raise
        else:
            print("Evaluation started.")
            return pod


# Raw view of the leaderboard gist (same gist id as YALL_GIST_ID in create_pod).
_LEADERBOARD_GIST_URL = "https://gist.githubusercontent.com/automerger/56ebbd012d942a6b749db5243de5740f/raw"


def download_leaderboard(url=_LEADERBOARD_GIST_URL, timeout=30):
    """
    Download the gist that contains the leaderboard.

    Args:
        url: URL of the raw markdown leaderboard.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The response body decoded as UTF-8.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection errors or timeout.
    """
    # Without a timeout a stalled connection would hang the Gradio refresh forever.
    response = requests.get(url, timeout=timeout)
    # Fail loudly instead of handing an HTML error page to the table parser.
    response.raise_for_status()
    return response.content.decode('utf-8')


def convert_markdown_table_to_dataframe(md_content):
    """
    Converts markdown table to Pandas DataFrame.

    The outer ``|`` of every line is stripped and the rest is parsed with
    ``|`` as the column separator. The first data row (the markdown
    ``|---|`` separator line) is dropped, so the index starts at 1. Column
    names are stripped of surrounding whitespace; cell values are not.
    """
    # Remove leading and trailing | characters on every line
    without_leading = re.sub(r'^\|\s*', '', md_content, flags=re.MULTILINE)
    cleaned_content = re.sub(r'\|\s*$', '', without_leading, flags=re.MULTILINE)

    # Raw string: "\|" in a plain literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12). pandas treats the 2-char separator
    # as a regex matching a literal "|".
    df = pd.read_csv(StringIO(cleaned_content), sep=r"\|", engine='python')

    # Remove the first row after the header (the |---| separator)
    df = df.drop(0, axis=0)

    # Strip whitespace from column names
    df.columns = df.columns.str.strip()

    return df


def get_dataframe():
    """
    Fetch the leaderboard gist and return it as a DataFrame.

    Used as the ``value`` callable of the leaderboard ``gr.Dataframe``.
    """
    return convert_markdown_table_to_dataframe(download_leaderboard())


def clear_data(dir_path="/data", remove_dirs=False):
    """
    Clear data so the Space doesn't crash...

    Deletes the files directly inside *dir_path* (default ``/data``).
    Subdirectories, such as the ``/data/merge`` output folder or the
    ``/data/hfcache`` datasets cache, are kept unless *remove_dirs* is True.
    In that case they are removed recursively; symlinked directories are never
    followed.

    Failures are reported per entry and the remaining entries are still
    processed; nothing is raised.
    """
    failures = []
    try:
        with os.scandir(dir_path) as entries:
            for entry in entries:
                try:
                    if entry.is_file():
                        os.unlink(entry.path)
                    elif remove_dirs and entry.is_dir(follow_symlinks=False):
                        shutil.rmtree(entry.path)
                except OSError as e:
                    failures.append(f"{entry.path}: {e}")
    except OSError as e:
        # e.g. dir_path does not exist or cannot be listed.
        failures.append(f"{dir_path}: {e}")

    if failures:
        print("Error occurred while deleting files.")
        for failure in failures:
            print(f"  {failure}")
    else:
        print("All files deleted successfully.")


def get_size(start_path):
    """Return the total size in bytes of all files under *start_path*.

    The tree is walked recursively with ``os.walk``; symbolic links to files
    are not counted. A missing *start_path* yields 0.
    """
    return sum(
        os.path.getsize(file_path)
        for root, _subdirs, names in os.walk(start_path)
        for file_path in (os.path.join(root, name) for name in names)
        if not os.path.islink(file_path)
    )


def human_readable_size(size, decimal_places=2):
    """Format a byte count using binary (1024-based) units from B up to PB.

    Values of 1024 PB or more stay in PB (e.g. ``"2048.00 PB"``) rather than
    being divided once more under the wrong label.

    Args:
        size: Number of bytes.
        decimal_places: Digits after the decimal point.

    Returns:
        A string such as ``"1.50 KB"``.
    """
    units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
    for unit in units[:-1]:
        if size < 1024.0:
            return f"{size:.{decimal_places}f} {unit}"
        size /= 1024.0
    # Largest unit: never divide again, whatever the magnitude.
    return f"{size:.{decimal_places}f} {units[-1]}"


def merge_loop():
    """
    Main function that orchestrates the merge.

    One iteration performs these steps in order:
    1. Build the leaderboard CSV and sample two parents from its top
       ``N_ROWS`` filtered rows.
    2. Pick a free repo name, a license and a random mergekit config.
    3. Merge, write the model card and upload to ``USERNAME``.
    4. Clear ``/data`` and start a RunPod evaluation.

    Returns early, without raising, when the CSV cannot be built or fewer
    than two candidate models remain after filtering.
    """
    # Start HF API
    api = HfApi(token=HF_TOKEN)

    # Create dataset (proceed only if successful)
    if not create_dataset():
        print("Failed to create dataset. Skipping merge loop.")
        return

    df = make_df("open-llm-leaderboard.csv", N_ROWS)
    # Explicit check rather than `assert` (stripped under -O, and an uncaught
    # AssertionError would stop the caller's loop). Two rows are required
    # because df.sample(n=2) raises on fewer.
    if len(df) < 2:
        print(f"Only {len(df)} candidate model(s) after filtering; need 2. Skipping merge loop.")
        return

    # Sample two models
    dir_path = "/data"
    sample = df.sample(n=2)
    models = [sample.iloc[i] for i in range(2)]

    # Get model name
    model_name = get_name(models, USERNAME, version=0)
    print("="*60)
    print(f"Model name: {model_name}")

    # Get model license
    license = get_license(models)
    print(f"License: {license}")

    # Merge configs
    yaml_config = create_config(models)
    print(f"YAML config:{yaml_config}")
    print(f"Data size: {human_readable_size(get_size(dir_path))}")

    # Merge models
    merge_models()
    print("Model merged!")

    # Create model card
    print("Create model card")
    create_model_card(yaml_config, model_name, USERNAME, license)

    # Upload model
    print("Upload model")
    upload_model(api, USERNAME, model_name)

    # Clear data
    print("Clear data")
    clear_data()

    # Evaluate model on Runpod
    print("Start evaluation")
    create_pod(model_name, USERNAME)
    print(f"Waiting for {WAIT_TIME/60} minutes...")

# Set the HF_DATASETS_CACHE environment variable (on the /data volume)
os.environ['HF_DATASETS_CACHE'] = "/data/hfcache/"

# Verify the environment variable is set
print(os.environ['HF_DATASETS_CACHE'])


def _run_setup_command(command: list[str]) -> None:
    """Run a setup *command* with captured output; print its stderr and re-raise on failure."""
    try:
        subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    except subprocess.CalledProcessError as e:
        # CalledProcessError's message omits the captured stderr; surface it.
        print(f"Setup command failed: {command}\n{e.stderr}")
        raise


# Install scrape-open-llm-leaderboard and mergekit.
# `git clone` refuses a non-empty target directory, so skip clones that are
# already present instead of aborting startup on a re-run.
if not os.path.isdir("scrape-open-llm-leaderboard"):
    _run_setup_command(["git", "clone", "-q", "https://github.com/Weyaxi/scrape-open-llm-leaderboard"])
_run_setup_command(["pip", "install", "-r", "scrape-open-llm-leaderboard/requirements.txt"])

if not os.path.isdir("mergekit"):
    _run_setup_command(["git", "clone", "https://github.com/arcee-ai/mergekit.git"])
_run_setup_command(["pip", "install", "-e", "mergekit"])

# Tee everything printed from here on into output.log; read_logs serves it to the UI.
sys.stdout = Logger("output.log")

# Gradio interface
title = """
<div align="center">
  <p style="font-size: 44px;">♾️ AutoMerger</p>
  <p style="font-size: 20px;">📃 <a href="https://huggingface.co/automerger">Merged models</a>\u2000•\u2000🏆 <a href="https://huggingface.co/spaces/automerger/Yet_Another_LLM_Leaderboard">Leaderboard</a>\u2000•\u2000📝 <a href="https://huggingface.co/blog/mlabonne/merge-models">Article</a>\u2000•\u2000🐦\u2000<a href="https://twitter.com/maximelabonne">Follow me on X</a></p>
  <p><em>AutoMerger selects two Llama 3 8B models on top of the Open LLM Leaderboard, combines them with a merge technique, and evaluates the resulting model.</em></p>
</div>
"""
footer = '<div align="center"><p><em>Special thanks to <a href="https://huggingface.co/Weyaxi">Weyaxi</a> for the <a href="https://github.com/Weyaxi/scrape-open-llm-leaderboard">Open LLM Leaderboard Scraper</a>, <a href="https://github.com/cg123">Charles Goddard</a> for <a href="https://github.com/arcee-ai/mergekit">mergekit</a>, and <a href="https://huggingface.co/MaziyarPanahi">Maziyar Panahi</a> for making <a href="https://huggingface.co/collections/MaziyarPanahi/gguf-65afc99c3997c4b6d2d9e1d5">GGUF versions</a> of these automerges.</em></p></div>'
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown(title)
    logs = gr.Textbox(label="Logs")
    # Re-run read_logs on a timer (every=10) to refresh the log box.
    demo.load(read_logs, None, logs, every=10)
    # The leaderboard table re-fetches the gist via get_dataframe (every=3600).
    leaderboard = gr.Dataframe(value=get_dataframe, datatype=["markdown", "number", "number", "number", "number", "number"], every=3600)
    gr.Markdown(footer)
# prevent_thread_lock=True lets launch() return so the merge loop below can run.
demo.queue(default_concurrency_limit=50).launch(server_name="0.0.0.0", show_error=True, prevent_thread_lock=True)

print("Start AutoMerger...")

# Main loop: one merge attempt, then sleep WAIT_TIME seconds, forever.
while True:
    try:
        merge_loop()
    except Exception as e:
        # Top-level boundary: a failure in one iteration (network, Hub, RunPod,
        # mergekit...) must not kill the Space. Report it to stdout (which is
        # tee'd into the UI log) and try again next cycle.
        print(f"merge_loop failed: {e!r}")
    time.sleep(WAIT_TIME)