import argparse
import json
import pathlib
import os

parser = argparse.ArgumentParser(
    description="Format the output of the data card tool as .md for the hub."
)
parser.add_argument("--input_path", "-i", type=pathlib.Path, required=False)
parser.add_argument("--output_path", "-o", type=pathlib.Path, required=False)
args = parser.parse_args()
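# Note: the parsed args are not used by the batch __main__ block below; they are
# only referenced by the commented-out single-file mode at the end of the file.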


def read_json_file(json_path: pathlib.Path):
    """Load a json file and return it as object."""
    with open(json_path, "r") as f:
        data = json.load(f)
    return data


def save_file(json_path: pathlib.Path, json_obj: dict):
    """Serializes a JSON-compatible object and writes it to json_path as formatted JSON."""
    with open(json_path, "w") as f:
        f.write(json.dumps(json_obj, indent=2))


def construct_json(dataset_name: str, data_card_data: dict, text_by_key: dict):
  """Constructs the json file

  This function iterates through text_by_key and extracts all answers from
  the data_card_data object. It uses the levels of hierarchy as indicator for
  the heading indentation and does not change the order in which anything
  appears.

  Args:
      data_card_data: Output from the data card tool
      text_by_key: configuration defined in key_to_question.json

  Returns:
      data_card_md_string: json content
  """

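  # Pull optional overview metadata; missing fields fall back to empty strings
  # (and the summary to a placeholder).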
  try:
    website_link = data_card_data["overview"]["where"]["website"]
  except KeyError:
    website_link = ""
  try:
    paper_link = data_card_data["overview"]["where"]["paper-url"]
  except KeyError:
    paper_link = ""
  try:
    authors = data_card_data["overview"]["credit"]["creators"]
  except KeyError:
    authors = ""
  try:
    summary = data_card_data["overview"]["what"]["dataset"]
  except KeyError:
    summary = "Placeholder"


  # Add summary blurb with loading script and link to GEM loader
  summary +=f"\n\nYou can load the dataset via:\n```\nimport datasets\ndata = datasets.load_dataset('GEM/{dataset_name}')\n```\nThe data loader can be found [here](https://huggingface.co/datasets/GEM/{dataset_name})."

  new_json = {
      "name": dataset_name,
      "summary": summary,
      "sections": [],
  }

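  # Only attach the optional metadata keys that the data card actually provides.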
  if website_link:
    new_json["website"] = website_link
  if paper_link:
    new_json["paper"] = paper_link
  if authors:
    new_json["authors"] = authors


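  # Simple statistics reported at the end of the run.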
  total_questions = 0
  total_words = 0

  for main_key, main_content in text_by_key.items():
    l2_data = {
        "title": main_content["section-title"],
        "level": 2,
        "subsections": [],
    }
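    # Skip top-level sections that are entirely missing from this data card.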
    if main_key not in data_card_data:
      continue
    for second_key, second_content in main_content.items():
      if second_key == "section-title":
        continue
      # Skip summary data since it is already in the header.
      if main_key == "overview" and second_key == "what":
        continue
      l3_data = {
          "title": second_content["section-title"],
          "level": 3,
          "fields": [],
      }
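      # Walk the individual questions configured for this subsection.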
      for final_key, final_content in second_content.items():
        if final_key == "section-title":
          continue
        total_questions += 1
        try:
          answer = data_card_data[main_key][second_key].get(final_key, "N/A")
        except (KeyError, AttributeError):
          # The subsection is missing (or not a dict) in this data card; skip it.
          continue
        # Skip empty answers.
        if isinstance(answer, str):
          if answer.lower() == "n/a":
            continue
        if not answer:
          continue

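        # Render list answers as a comma-separated list of backtick-quoted values.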
        if isinstance(answer, list):
          answer = ", ".join([f"`{a}`" for a in answer])

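        # Each answered question becomes a level-4 field that carries the
        # display metadata defined in the configuration.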
        json_answer = {
          "title": final_content["title"],
          "level": 4,
          "content": answer,
          "flags": final_content["flags"],
          "info": final_content["info"],
          "scope": final_content["scope"],
        }
        total_words += len(answer.split())
        l3_data["fields"].append(json_answer)
      l2_data["subsections"].append(l3_data)
    new_json["sections"].append(l2_data)
  print(f"Total questions {total_questions}")
  print(f"total words: {total_words}")
  return new_json, total_words




if __name__ == "__main__":

  text_by_key = read_json_file(
      os.path.join(os.path.dirname(__file__), "key_to_question.json")
  )
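  # Batch mode: convert every data card in the GEMv2 checkout; the path is
  # resolved relative to the current working directory.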
  total_words_across_everything = 0
  # Create the output directory up front so save_file does not fail on a fresh checkout.
  os.makedirs("datacards", exist_ok=True)
  for dataset in os.listdir("../../../GEMv2"):
    data_card_path = f"../../../GEMv2/{dataset}/{dataset}.json"
    if os.path.exists(data_card_path):
      print(f"Now processing {dataset}.")
      new_path = f"datacards/{dataset}.json"
      data_card_data = read_json_file(data_card_path)
      data_card_json, total_cur_words = construct_json(dataset, data_card_data, text_by_key)
      total_words_across_everything += total_cur_words

      save_file(new_path, data_card_json)
    else:
      print(f"{dataset} has no data card!")
  print(f"Total words across all data cards: {total_words_across_everything}")
  # data_card_json = construct_json(data_card_data, text_by_key)
  # save_file(args.output_path, data_card_json)