File size: 3,365 Bytes
8bf4dee 0e936e1 8bf4dee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import argparse
import subprocess
import sys
from collections import defaultdict
from multiprocessing.pool import ThreadPool
from typing import List, NamedTuple

import wandb
import wandb.apis.public
class RunGroup(NamedTuple):
    """Immutable, hashable (algo, env_id) pair used to bucket runs together."""

    # Algorithm name.
    algo: str
    # Environment identifier.
    env_id: str
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the benchmark publishing script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-benchmarks",
        help="WandB project name to load runs from",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
    parser.add_argument(
        "--envs", type=str, nargs="*", help="Optional filter down to these envs"
    )
    parser.add_argument(
        "--exclude-envs",
        type=str,
        nargs="*",
        help="Environments to exclude from publishing",
    )
    parser.add_argument(
        "--huggingface-user",
        type=str,
        default=None,
        help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
    )
    parser.add_argument(
        "--pool-size",
        type=int,
        default=3,
        help="How many publish jobs can run in parallel",
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    # Uncomment to inject defaults when debugging without command-line arguments:
    # parser.set_defaults(
    #     wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
    #     wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
    #     envs=[],
    #     exclude_envs=[],
    # )
    return parser


def _publish_command(args: argparse.Namespace, run_paths: List[str]) -> List[str]:
    """Build the argv that runs huggingface_publish.py for one group of runs."""
    # Launch the child with the interpreter running this script so it sees the
    # same environment; fall back to "python" if the interpreter path is unknown.
    command = [sys.executable or "python", "huggingface_publish.py"]
    command.append("--wandb-run-paths")
    command.extend(run_paths)
    # Optional flags are forwarded only when set: a None entry in the argv list
    # makes subprocess.run raise TypeError.
    if args.wandb_report_url:
        command.extend(["--wandb-report-url", args.wandb_report_url])
    if args.huggingface_user:
        command.extend(["--huggingface-user", args.huggingface_user])
    if args.virtual_display:
        command.append("--virtual-display")
    return command


def benchmark_publish() -> None:
    """Publish finished WandB benchmark runs, one publish job per (algo, env).

    Parses the command line, loads the runs of the selected WandB project,
    keeps the finished runs that carry every requested tag and pass the
    environment filters, groups them by (algo, env), and runs
    ``huggingface_publish.py`` once per group on a thread pool of
    ``--pool-size`` workers.

    Jobs that raise or whose subprocess exits with a non-zero status are
    reported on stderr once all jobs have completed.

    Raises:
        SystemExit: via ``argparse`` when arguments are invalid or
            ``--wandb-tags`` is missing.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    print(args)
    # Validate after parsing (rather than required=True) so that values supplied
    # through parser.set_defaults() are still accepted.
    if not args.wandb_tags:
        parser.error("the following arguments are required: --wandb-tags")

    api = wandb.Api()
    all_runs = api.runs(
        f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
    )

    required_tags = set(args.wandb_tags)
    runs_paths_by_group = defaultdict(list)
    for r in all_runs:
        # A run qualifies only if it carries every requested tag.
        if not required_tags.issubset(r.config.get("wandb_tags", [])):
            continue
        if r.state != "finished":
            continue
        algo = r.config["algo"]
        env = r.config["env"]
        if args.envs and env not in args.envs:
            continue
        if args.exclude_envs and env in args.exclude_envs:
            continue
        runs_paths_by_group[RunGroup(algo, env)].append("/".join(r.path))

    def run(run_paths: List[str]) -> int:
        """Run the publish subprocess for one group and return its exit status."""
        return subprocess.run(_publish_command(args, run_paths)).returncode

    with ThreadPool(args.pool_size) as tp:
        pending = [
            (run_group, tp.apply_async(run, (run_paths,)))
            for run_group, run_paths in runs_paths_by_group.items()
        ]
        tp.close()
        tp.join()

    # apply_async stores worker exceptions on the result object; they only
    # surface when get() is called, so every result is checked explicitly.
    failures = []
    for run_group, result in pending:
        try:
            returncode = result.get()
        except Exception as exc:  # job boundary: report the failure, keep going
            failures.append(f"{run_group}: raised {exc!r}")
        else:
            if returncode != 0:
                failures.append(f"{run_group}: exited with status {returncode}")
    if failures:
        print(
            f"{len(failures)} of {len(pending)} publish jobs failed:",
            file=sys.stderr,
        )
        for failure in failures:
            print(f"  {failure}", file=sys.stderr)
# Script entry point: run the publisher only when executed directly, not on import.
if __name__ == "__main__":
    benchmark_publish()
|