|
|
|
import argparse |
|
from huggingface_hub import HfApi |
|
|
|
|
|
def main(api: "HfApi", model_id: str) -> "list[str]":
    """Return the names of every branch of a Hub repo except ``main``.

    Args:
        api: Client exposing ``list_repo_refs`` (the script passes a
            ``huggingface_hub.HfApi`` instance).
        model_id: Repo id on the Hub, e.g. ``gpt2`` or
            ``facebook/wav2vec2-base-960h``.

    Returns:
        The unique branch names other than ``main``, sorted alphabetically
        so the result is deterministic. Empty if the repo has no other
        branches.
    """
    refs = api.list_repo_refs(model_id)
    # De-duplicate via a set, then drop the default "main" branch.
    branches = {branch.name for branch in refs.branches} - {"main"}
    # Sorting gives a stable order; a bare list(set) order is arbitrary.
    return sorted(branches)
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: print the given model id if its repo has a
    # "non-ema" branch; print nothing otherwise.
    DESCRIPTION = """

    Simple utility to get all branches from a repo

    """

    parser = argparse.ArgumentParser(description=DESCRIPTION)
    # required=True: without it an omitted flag leaves model_id as None,
    # which would then be passed straight to the Hub API call below.
    parser.add_argument(
        "--model_id",
        type=str,
        required=True,
        help="The name of the model on the hub to retrieve the branches from. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
    )

    args = parser.parse_args()
    model_id = args.model_id

    api = HfApi()
    branches = main(api, model_id)

    # Only repos that expose a "non-ema" branch produce output.
    if "non-ema" in branches:
        print(model_id)
|
|
|
|
|
|
|
|