# zhangbo2008's picture
# Duplicate from facebook/ov-seg
# 7e8c559
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import os
import wandb
from detectron2.utils import comm
from detectron2.utils.events import EventWriter, get_event_storage
def setup_wandb(cfg, args):
    """
    Start a wandb run configured from ``cfg.WANDB`` (main process only).

    Every string key of ``cfg.WANDB`` is lower-cased and forwarded to
    ``wandb.init`` as a keyword argument. Keys spelled exactly ``"config"`` or
    ``"name"`` are skipped; the check is done before lower-casing.

    The run's ``config`` is chosen as follows:
      * If ``CONFIG_EXCLUDE_KEYS`` is present, it is a copy of the whole
        ``cfg`` plus ``"cfg_file"``.
      * Otherwise it is only the model/solver sections plus ``"cfg_file"``, to
        keep the logged table small.

    If no name was given (or it is None), the run is named after the basename
    of ``args.config_file``.

    Args:
        cfg: detectron2 config node; must have ``WANDB``, ``MODEL`` and
            ``SOLVER`` entries.
        args: parsed command-line arguments; ``args.config_file`` is read.
    """
    if not comm.is_main_process():
        return

    init_args = {
        k.lower(): v
        for k, v in cfg.WANDB.items()
        if isinstance(k, str) and k not in ["config", "name"]
    }

    # TODO: add configurable params to select which part of `cfg` should be
    # saved in config.
    if "config_exclude_keys" in init_args:
        # Shallow-copy into a plain dict so that adding "cfg_file" does not
        # inject a new key into the caller's (possibly frozen) cfg. Item
        # assignment on a CfgNode bypasses its freeze check.
        config = dict(cfg)
        config["cfg_file"] = args.config_file
        init_args["config"] = config
    else:
        # Only include the most relevant parts to avoid a too-big table.
        init_args["config"] = {
            "model": cfg.MODEL,
            "solver": cfg.SOLVER,
            "cfg_file": args.config_file,
        }

    if init_args.get("name") is None:
        init_args["name"] = os.path.basename(args.config_file)
    wandb.init(**init_args)
class BaseRule:
    """Identity rule and common base class of the naming rules.

    Calling an instance returns ``target`` unchanged.
    """

    def __call__(self, target):
        # Pass-through: no renaming or filtering is applied.
        unchanged = target
        return unchanged
class IsIn(BaseRule):
    """Predicate rule: true when ``keyword`` occurs in the target."""

    def __init__(self, keyword: str):
        self.keyword = keyword

    def __call__(self, target):
        found = self.keyword in target
        return found
class Prefix(BaseRule):
    """Transform rule: nests the target under ``<keyword>/``."""

    def __init__(self, keyword: str):
        self.keyword = keyword

    def __call__(self, target):
        separator = "/"
        return separator.join((self.keyword, target))
class WandbWriter(EventWriter):
    """
    Write scalars, images and histograms from detectron2's event storage to
    the active Weights & Biases (wandb) run.
    """

    def __init__(self):
        """
        Create a writer that logs to the currently active wandb run.

        Takes no arguments. Scalar names are remapped by ``self._group_rules``:
        the first rule whose predicate matches transforms the name, and
        otherwise the name is kept as is.
        """
        # Iteration of the newest scalar already sent to wandb; only scalars
        # updated after this iteration are logged by the next write().
        self._last_write = -1
        # Ordered (predicate, transform) pairs:
        #   - names that already contain "/" keep their own grouping;
        #   - names containing "loss" are grouped under "train/".
        self._group_rules = [
            (IsIn("/"), BaseRule()),
            (IsIn("loss"), Prefix("train")),
        ]

    def _group_name(self, scalar_name):
        """Return ``scalar_name`` transformed by the first matching rule."""
        for rule, op in self._group_rules:
            if rule(scalar_name):
                return op(scalar_name)
        return scalar_name

    @staticmethod
    def _to_wandb_layout(img):
        """
        Convert a 3-D numpy image from CHW to HWC; return other inputs as is.

        ``EventStorage.put_image`` takes images as [channel, height, width]
        (the layout the tensorboard writer consumes). ``wandb.Image``
        interprets numpy arrays as height x width x channel. Torch tensors are
        left alone because wandb handles their CHW layout itself.
        """
        import numpy as np

        if isinstance(img, np.ndarray) and img.ndim == 3:
            return np.transpose(img, (1, 2, 0))
        return img

    @staticmethod
    def _create_bar(tag, bucket_limits, bucket_counts, **kwargs):
        """
        Build a wandb bar plot from one histogram record in the storage.

        Only ``tag``, ``bucket_limits`` and ``bucket_counts`` are used; any
        other keys of the record are accepted via ``kwargs`` and ignored.
        """
        data = [[label, val] for label, val in zip(bucket_limits, bucket_counts)]
        table = wandb.Table(data=data, columns=["label", "value"])
        return wandb.plot.bar(table, "label", "value", title=tag)

    def write(self):
        """
        Log everything new in the current event storage to wandb.

        Steps:
          * Scalars updated since the previous write are logged, with names
            remapped by the group rules.
          * Any pending images and histograms are logged and then cleared
            from the storage.

        Nothing is logged when there is nothing new. The wandb step is
        ``storage.iter``.
        """
        storage = get_event_storage()
        latest = storage.latest()

        stats = {
            self._group_name(name): scalars[0]
            for name, scalars in latest.items()
            if scalars[1] > self._last_write
        }
        if stats:
            self._last_write = max(v[1] for v in latest.values())

        # storage.put_{image,histogram} is only meant to be used by the
        # tensorboard writer, so its internal fields are accessed directly.
        if storage._vis_data:
            stats["image"] = [
                wandb.Image(self._to_wandb_layout(img), caption=img_name)
                for img_name, img, _step_num in storage._vis_data
            ]
            # Storage stores all image data and relies on this writer to clear
            # it. As a result it assumes only one writer will use its image
            # data. An alternative design is to let storage keep only limited
            # recent data (e.g. the most recent image) that all writers can
            # access; then a writer with a long period may miss images.
            storage.clear_images()

        if storage._histograms:
            stats["hist"] = [
                self._create_bar(**params) for params in storage._histograms
            ]
            storage.clear_histograms()

        if stats:
            wandb.log(stats, step=storage.iter)

    def close(self):
        """Finish the active wandb run."""
        wandb.finish()