""" |
|
A script to benchmark builtin models. |
|
|
|
Note: this script has an extra dependency of psutil. |
|
""" |
import itertools
import logging
import psutil
import torch
import tqdm
from fvcore.common.timer import Timer
from torch.nn.parallel import DistributedDataParallel

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (
    DatasetFromList,
    build_detection_test_loader,
    build_detection_train_loader,
)
from detectron2.engine import SimpleTrainer, default_argument_parser, hooks, launch
from detectron2.modeling import build_model
from detectron2.solver import build_optimizer
from detectron2.utils import comm
from detectron2.utils.events import CommonMetricPrinter
from detectron2.utils.logger import setup_logger

logger = logging.getLogger("detectron2")


def setup(args):
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.SOLVER.BASE_LR = 0.001  # keep the LR small so benchmark training does not diverge to NaNs
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    setup_logger(distributed_rank=comm.get_rank())
    return cfg


def benchmark_data(args):
    cfg = setup(args)

    timer = Timer()
    dataloader = build_detection_train_loader(cfg)
    logger.info("Initialize loader using {} seconds.".format(timer.seconds()))

    timer.reset()
    itr = iter(dataloader)
    for i in range(10):  # warmup
        next(itr)
        if i == 0:
            startup_time = timer.seconds()
    timer = Timer()
    max_iter = 1000
    for _ in tqdm.trange(max_iter):
        next(itr)
    logger.info(
        "{} iters ({} images) in {} seconds.".format(
            max_iter, max_iter * cfg.SOLVER.IMS_PER_BATCH, timer.seconds()
        )
    )
    logger.info("Startup time: {} seconds".format(startup_time))
    vram = psutil.virtual_memory()  # system RAM of this machine, not GPU memory
    logger.info(
        "RAM Usage: {:.2f}/{:.2f} GB".format(
            (vram.total - vram.available) / 1024 ** 3, vram.total / 1024 ** 3
        )
    )

    # run a few more rounds to check that the throughput stays stable
    for _ in range(10):
        timer = Timer()
        max_iter = 1000
        for _ in tqdm.trange(max_iter):
            next(itr)
        logger.info(
            "{} iters ({} images) in {} seconds.".format(
                max_iter, max_iter * cfg.SOLVER.IMS_PER_BATCH, timer.seconds()
            )
        )


def benchmark_train(args):
    cfg = setup(args)
    model = build_model(cfg)
    logger.info("Model:\n{}".format(model))
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )
    optimizer = build_optimizer(cfg, model)
    checkpointer = DetectionCheckpointer(model, optimizer=optimizer)
    checkpointer.load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    data_loader = build_detection_train_loader(cfg)
    dummy_data = list(itertools.islice(data_loader, 100))

    def f():
        # Replay the same 100 cached batches forever, so data loading is
        # excluded from the measured training time.
        data = DatasetFromList(dummy_data, copy=False)
        while True:
            yield from data

    max_iter = 400
    trainer = SimpleTrainer(model, f(), optimizer)
    trainer.register_hooks(
        [hooks.IterationTimer(), hooks.PeriodicWriter([CommonMetricPrinter(max_iter)])]
    )
    trainer.train(1, max_iter)


@torch.no_grad()
def benchmark_eval(args):
    cfg = setup(args)
    model = build_model(cfg)
    model.eval()
    logger.info("Model:\n{}".format(model))
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)

    cfg.defrost()
    cfg.DATALOADER.NUM_WORKERS = 0
    data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    dummy_data = list(itertools.islice(data_loader, 100))

    def f():
        # Replay the same 100 cached inputs forever.
        while True:
            yield from DatasetFromList(dummy_data, copy=False)

    for _ in range(5):  # warmup
        model(dummy_data[0])

    max_iter = 400
    timer = Timer()
    with tqdm.tqdm(total=max_iter) as pbar:
        for idx, d in enumerate(f()):
            if idx == max_iter:
                break
            model(d)
            pbar.update()
    logger.info("{} iters in {} seconds.".format(max_iter, timer.seconds()))
if __name__ == "__main__": |
|
parser = default_argument_parser() |
|
parser.add_argument("--task", choices=["train", "eval", "data"], required=True) |
|
args = parser.parse_args() |
|
assert not args.eval_only |
|
|
|
if args.task == "data": |
|
f = benchmark_data |
|
elif args.task == "train": |
|
""" |
|
Note: training speed may not be representative. |
|
The training cost of a R-CNN model varies with the content of the data |
|
and the quality of the model. |
|
""" |
|
f = benchmark_train |
|
elif args.task == "eval": |
|
f = benchmark_eval |
|
|
|
assert args.num_gpus == 1 and args.num_machines == 1 |
|
launch(f, args.num_gpus, args.num_machines, args.machine_rank, args.dist_url, args=(args,)) |
|
|