#!/usr/bin/env python3
import argparse
from huggingface_hub import HfApi


def main(api, model_id):
    """Return the names of every branch of a Hub repo except ``main``.

    Args:
        api: An ``HfApi`` instance (or any object with a compatible
            ``list_repo_refs(model_id)`` method) used to query the repo.
        model_id: Repository id on the Hub, e.g. ``"gpt2"`` or
            ``"facebook/wav2vec2-base-960h"``.

    Returns:
        A list of branch names, excluding ``"main"``, with duplicates removed
        and sorted alphabetically. Sorting makes the output deterministic;
        plain set iteration order for strings varies between interpreter runs.
    """
    refs = api.list_repo_refs(model_id)
    branch_names = {branch.name for branch in refs.branches}
    # ``main`` always exists on a Hub repo, so only the extra branches matter.
    branch_names.discard("main")
    return sorted(branch_names)


if __name__ == "__main__":
    DESCRIPTION = """
    Print the model id if its Hub repository has a `non-ema` branch.
    With --print_all, print every non-`main` branch of the repo instead.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument(
        "--model_id",
        type=str,
        # Required: otherwise model_id would be None, which is not a valid
        # repo id, and the failure would surface from the Hub call rather
        # than as an argparse usage error.
        required=True,
        help="The name of the model on the hub to retrieve the branches from. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )
    parser.add_argument(
        "--print_all",
        action="store_true",
        help="Print `<model_id>: [branches]` whenever the repo has branches other than `main`, instead of only reporting repos with a `non-ema` branch.",
    )

    args = parser.parse_args()
    model_id = args.model_id
    api = HfApi()
    branches = main(api, model_id)

    if args.print_all:
        # Alternative report mode (formerly commented-out code), now opt-in.
        if branches:
            print(f"{model_id}: {branches}")
    elif "non-ema" in branches:
        # Default mode: emit just the id so output can be piped/collected.
        print(model_id)