import argparse
import os
from datasets import load_dataset, Dataset
from huggingface_hub import HfApi

# The Hub token for this debug pipeline is read from the DEBUG environment variable.
TOKEN = os.environ.get("DEBUG")
api = HfApi(token=TOKEN)

REQUESTS_DSET = "AIEnergyScore/requests_debug"
RESULTS_DSET = "AIEnergyScore/results_debug"
PENDING = 'PENDING'
COMPLETED = 'COMPLETED'
FAILED = 'FAILED'
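# Request lifecycle: rows in the requests dataset start as PENDING and are
# flipped to COMPLETED or FAILED once their run directory has been examined
# (see update_requests below).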

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--run_dir",
        default="/runs",
        type=str,
        required=False,
        help="Path to the run directory.",
    )
    parser.add_argument(
        "--attempts",
        default="/attempts.txt",
        type=str,
        required=False,
        help="File with per-line run attempt directories. Assumes format '/runs/{task}/{model}/{timestamp}'",
    )
    parser.add_argument(
        "--failed_attempts",
        default="/failed_attempts.txt",
        type=str,
        required=False,
        help="File with per-line failed run directories. Assumes format '/runs/{task}/{model}/{timestamp}'",
    )
    args = parser.parse_args()
    return args
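
# Example invocation (the script filename and paths here are illustrative):
#   python update_requests.py --run_dir /runs \
#       --attempts /attempts.txt --failed_attempts /failed_attempts.txt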

def check_for_traceback(run_dir):
    # run_dir="./runs/${experiment_name}/${backend_model}/${now}"
    # Returns the contents of error.log from the first "Traceback" line to the
    # end of the file, or "" if no traceback (or no error.log) is found.
    found_error = False
    error_message = ""
    try:
        # Read the error message, capturing from the start of the traceback
        # onward. Checking for the "Traceback" marker, rather than for a
        # non-empty error.log, means stray stderr output alone is not
        # treated as a failure.
        with open(f"{run_dir}/error.log", "r") as f:
            for line in f:
                if "Traceback (most recent call last):" in line:
                    found_error = True
                if found_error:
                    error_message += line
    except FileNotFoundError:
        # Runs that never wrote an error log have nothing to read here.
        print(f"Could not find {run_dir}/error.log")
    return error_message
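
# Illustrative error.log (hypothetical contents) and what gets captured:
#
#   benchmark warning: something benign      <- ignored
#   Traceback (most recent call last):       <- capture starts here
#     File "run.py", line 10, in <module>    <- captured
#   ValueError: out of memory                <- captured
#
# check_for_traceback would return the last three lines above, joined as one string.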

def update_requests(requests, all_attempts, failed_attempts):
    """
    Sets all PENDING requests with a given model & task to COMPLETED or FAILED.
    Reads the all_attempts and failed_attempts lists, in which each line is a
    run directory run_dir="/runs/${experiment_name}/${backend_model}/${now}"

    :param requests: requests Dataset
    :param all_attempts: lines listing the run directory of each attempted task/model/timestamp
    :param failed_attempts: lines listing the run directory of each failed task/model/timestamp
    :return: the updated requests Dataset
    """
    requests_df = requests.to_pandas()
    # Each line is a run directory, where
    # run_dir="/runs/${experiment_name}/${backend_model}/${now}", where
    # ${backend_model} is ${organization}/${model_name}
    for line in all_attempts:
        line = line.strip()
        print(f"Checking {line}")
        split_run_dir = line.strip("/").split("/")
        print(f"Processing run directory {split_run_dir}")
        task = split_run_dir[1]
        print(f"Task is {task}")
        # The naming of the optimum benchmark configs uses an underscore;
        # the naming of the HF Api list models function uses a hyphen. We
        # therefore adapt the task string depending on which part of the
        # pipeline we're talking to, e.g. "text_generation" -> "text-generation".
        hyphenated_task_name = task.replace("_", "-")
        model = "/".join([split_run_dir[2], split_run_dir[3]])
        print(f"Model is {model}")
        # Select the matching PENDING rows once, *before* updating the status;
        # re-filtering on status == PENDING after setting it to FAILED would
        # match nothing, so the error message would never be recorded.
        pending_mask = (
            (requests_df["status"] == PENDING)
            & (requests_df["model"] == model)
            & (requests_df["task"] == hyphenated_task_name)
        )
        traceback_error = check_for_traceback(line)
        if traceback_error != "":
            print("Found a traceback error!")
            print(traceback_error)
            requests_df.loc[pending_mask, ["status"]] = FAILED
            requests_df.loc[pending_mask, ["error_message"]] = traceback_error
        elif line in failed_attempts:
            print(f"Job failed, but not sure why -- didn't find a traceback in {line}.")
            print(f"Setting {model}, {hyphenated_task_name}, status {PENDING} to {FAILED}.")
            print(requests_df[pending_mask])
            requests_df.loc[pending_mask, ["status"]] = FAILED
        else:
            requests_df.loc[pending_mask, ["status"]] = COMPLETED
    updated_dset = Dataset.from_pandas(requests_df)
    return updated_dset
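
# Example (hypothetical data): a requests row
#   {"model": "org/model", "task": "text-generation", "status": "PENDING"}
# whose run directory "/runs/text_generation/org/model/2024-01-01" appears in
# all_attempts with no traceback in its error.log, and which is absent from
# failed_attempts, comes back with status COMPLETED.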

if __name__ == '__main__':
    args = parse_args()
    # Uploads all run output to the results dataset.
    print(f"Uploading {args.run_dir} to {RESULTS_DSET}")
    api.upload_folder(
        folder_path=args.run_dir,
        repo_id=RESULTS_DSET,
        repo_type="dataset",
    )
    # Update requests dataset based on whether things have failed or not.
    print(f"Examining the run directory for each model & task to determine if it {FAILED} or {COMPLETED}.")
    requests = load_dataset(REQUESTS_DSET, split="test", token=TOKEN)
    # Strip newlines on read so run directories can be matched exactly against
    # the stripped lines in update_requests, and close the files promptly.
    with open(args.attempts, "r") as f:
        all_attempts = [attempt.strip() for attempt in f.readlines()]
    with open(args.failed_attempts, "r") as f:
        failed_attempts = [attempt.strip() for attempt in f.readlines()]
    updated_requests = update_requests(requests, all_attempts, failed_attempts)
    print(f"Uploading updated {REQUESTS_DSET}.")
    updated_requests.push_to_hub(REQUESTS_DSET, split="test", token=TOKEN)
    print("Done.")