# climateGAN / tests/test_trainer.py
# Author: vict0rsch
# Commit 448ebbd: "initial commit from cc-ai/climateGAN"
print("Imports...", end="", flush=True)
import sys
from pathlib import Path
# Append the directory two levels above this file (the repository root, given
# this file lives in tests/) to sys.path, presumably so that `import climategan`
# below resolves to the local checkout. This must run before that import.
sys.path.append(str(Path(__file__).resolve().parent.parent))
import atexit
import logging
from argparse import ArgumentParser
from copy import deepcopy
import comet_ml
import climategan
from comet_ml.api import API
from climategan.trainer import Trainer
from climategan.utils import get_comet_rest_api_key
# Configure the root logger, then drop everything below ERROR (presumably to
# keep scenario output readable).
logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)
import traceback
print("Done.")
def set_opts(opts, str_nested_key, value):
    """
    Assign ``value`` inside a nested mapping addressed by a dotted key path.

    All path segments except the last are looked up with ``[]`` to walk down
    into ``opts``; the last segment is then assigned in place. ``opts`` is
    mutated and nothing is returned. With an ``addict.Dict`` (which creates
    missing keys on access) the call ``set_opts(d, "a.b.c", 2)`` leaves ``d``
    equal to ``Dict({"a": {"b": {"c": 2}}})``; with a plain dict, a missing
    intermediate key raises ``KeyError``.

    Args:
        opts (addict.Dict): opts whose values should be changed (in place)
        str_nested_key (str): nested keys joined on "."
        value (any): value to store at the innermost key
    """
    *parents, leaf = str_nested_key.split(".")
    target = opts
    for parent in parents:
        target = target[parent]
    target[leaf] = value
def set_conf(opts, conf):
    """
    Apply a test scenario's configuration dict to trainer options.

    Each key of ``conf`` is a dotted opts path written with ``set_opts``.
    Keys starting with "__" are scenario-level settings (``__doc``,
    ``__use_comet``, ...) that are not part of the opts, so they are skipped.

    Args:
        opts (addict.Dict): trainer options, mutated in place
        conf (dict): scenario's configuration
    """
    for nested_key, new_value in conf.items():
        if not nested_key.startswith("__"):
            set_opts(opts, nested_key, new_value)
class bcolors:
    """
    ANSI escape sequences used to style terminal output in this script.

    ``ENDC`` is appended after every styled span (see ``Colors._r``) to close
    it. All other attributes open a style.
    """

    # color codes
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    # closing code
    ENDC = "\033[0m"
    # text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
class Colors:
    """
    Helpers that wrap text in a ``bcolors`` style.

    Every public method joins its positional string arguments with single
    spaces, puts the matching ``bcolors`` code in front and ``bcolors.ENDC``
    behind. Calls nest to combine styles, e.g. ``p.b(p.og("Done"))``.
    """

    def _r(self, key, *args):
        # `key` opens the style; ENDC closes it after the joined text
        joined = " ".join(args)
        return f"{key}{joined}{bcolors.ENDC}"

    def ob(self, *args):
        """Blue (OKBLUE) text."""
        return self._r(bcolors.OKBLUE, *args)

    def w(self, *args):
        """Warning-colored (WARNING) text."""
        return self._r(bcolors.WARNING, *args)

    def og(self, *args):
        """Green (OKGREEN) text."""
        return self._r(bcolors.OKGREEN, *args)

    def f(self, *args):
        """Failure-colored (FAIL) text."""
        return self._r(bcolors.FAIL, *args)

    def b(self, *args):
        """Bold text."""
        return self._r(bcolors.BOLD, *args)

    def u(self, *args):
        """Underlined text."""
        return self._r(bcolors.UNDERLINE, *args)
def comet_handler(exp, api):
    """
    Build a no-argument callback that deletes a comet experiment.

    The callback prints a bold warning banner, then asks ``api`` to delete
    the experiment identified by ``exp.get_key()``. The key is read when the
    callback runs, not when it is built. ``delete_on_exit`` registers it with
    ``atexit``.

    Args:
        exp (comet_ml.Experiment): experiment to delete
        api (comet_ml.api.API): REST API client used for the deletion

    Returns:
        callable: zero-argument deletion callback
    """

    def sub_handler():
        colors = Colors()
        banner = colors.b(colors.w("Deleting comet experiment"))
        print()
        print(banner)
        api.delete_experiment(exp.get_key())

    return sub_handler
def print_start(desc):
    """
    Print ``desc`` in bold blue, framed by dashed lines, as a section header.

    The dashed lines are ``len(desc) + 6`` characters long. They are sized on
    the raw text, since the escape codes do not take up visible width.

    Args:
        desc (str): header text
    """
    painter = Colors()
    border = "-" * (len(desc) + 6)
    framed = f"| {painter.b(painter.ob(desc))} |"
    print("\n".join((border, framed, border)))
def print_end(desc=None, ok=None):
    """
    Print a framed footer for a section or a test scenario.

    Args:
        desc (str, optional): text to frame. When given, it is printed in bold
            green whatever ``ok`` is.
        ok (bool, optional): outcome, used only when ``desc`` is None:
            truthy -> "Done" (bold green), falsy -> "! Fail !" (bold red),
            None -> "Unknown" (uncolored).
    """
    p = Colors()
    if desc is not None:
        cdesc = p.b(p.og(desc))
    elif ok is None:
        # Must be tested before the falsy case: `not None` is True, so an
        # unreported outcome would otherwise be shown as a failure and this
        # branch could never be reached.
        desc = "Unknown"
        cdesc = desc
    elif ok:
        desc = "Done"
        cdesc = p.b(p.og(desc))
    else:
        desc = "! Fail !"
        cdesc = p.b(p.f(desc))
    title = "| " + cdesc + " |"
    # sized on the uncolored text: escape codes take no visible width
    line = "-" * (len(desc) + 6)
    print(f"{line}\n{title}\n{line}\n")
def delete_on_exit(exp):
    """
    Arrange for a comet experiment to be deleted when the interpreter exits.

    Builds a comet REST ``API`` client from ``get_comet_rest_api_key()`` and
    registers the callback returned by ``comet_handler`` with ``atexit``.

    Args:
        exp (comet_ml.Experiment): The exp to delete
    """
    api = API(api_key=get_comet_rest_api_key())
    cleanup = comet_handler(exp, api)
    atexit.register(cleanup)
if __name__ == "__main__":
    # -----------------------------
    # -----  Parse Arguments  -----
    # -----------------------------
    parser = ArgumentParser()
    parser.add_argument("--no_delete", action="store_true", default=False)
    # NOTE(review): this flag is parsed but never read below, so scenario 8
    # ("MSDP with End-to-end") still runs when it is set; use `--exclude 8`.
    parser.add_argument("--no_end_to_end", action="store_true", default=False)
    parser.add_argument("--include", "-i", nargs="+", default=[])
    parser.add_argument("--exclude", "-e", nargs="+", default=[])
    args = parser.parse_args()

    # `assert` is stripped under `python -O`; parser.error always rejects the
    # combination and exits with a usage message.
    if args.include and args.exclude:
        parser.error("Choose 1: include XOR exclude")

    include = {int(i) for i in args.include}
    exclude = {int(i) for i in args.exclude}

    if include:
        print("Including exclusively tests", " ".join(args.include))
    if exclude:
        print("Excluding tests", " ".join(args.exclude))

    # --------------------------------------
    # -----  Create global experiment  -----
    # --------------------------------------
    print("Creating comet Experiment...", end="", flush=True)
    global_exp = comet_ml.Experiment(
        project_name="climategan-test", display_summary_level=0
    )
    print("Done.")

    if not args.no_delete:
        delete_on_exit(global_exp)

    # -------------------------------------
    # -----  Base Test Scenario Opts  -----
    # -------------------------------------
    print("Loading opts...", end="", flush=True)
    base_opts = climategan.utils.load_opts()
    base_opts.data.check_samples = False
    base_opts.train.fid.n_images = 5
    base_opts.comet.display_size = 5
    base_opts.tasks = ["m", "s", "d"]
    base_opts.domains = ["r", "s"]
    base_opts.data.loaders.num_workers = 4
    base_opts.data.loaders.batch_size = 2
    base_opts.data.max_samples = 9
    base_opts.train.epochs = 1
    # the last transform's new_size is either a plain int or a nested
    # structure with a `default` field
    if isinstance(base_opts.data.transforms[-1].new_size, int):
        base_opts.data.transforms[-1].new_size = 256
    else:
        base_opts.data.transforms[-1].new_size.default = 256
    print("Done.")

    # --------------------------------------
    # -----  Configure Test Scenarios  -----
    # --------------------------------------

    # override any nested key in opts
    # create scenario-specific variables with __key
    # ALWAYS specify a __doc key to describe your scenario
    # The trailing `# N` comments give each scenario's index, as used by
    # --include / --exclude: keep them in sync when editing the list.
    test_scenarios = [
        {"__use_comet": False, "__doc": "MSD no exp", "__verbose": 1},  # 0
        {"__doc": "MSD with exp"},  # 1
        {
            "__doc": "MSD no exp upsample_featuremaps",  # 2
            "__use_comet": False,
            "gen.d.upsample_featuremaps": True,
            "gen.s.upsample_featuremaps": True,
        },
        {"tasks": ["p"], "domains": ["rf"], "__doc": "Painter"},  # 3
        {
            "__doc": "M no exp low level feats",  # 4
            "__use_comet": False,
            "gen.m.use_low_level_feats": True,
            "gen.m.use_dada": False,
            "tasks": ["m"],
        },
        {
            "__doc": "MSD no exp deeplabv2",  # 5
            "__use_comet": False,
            "gen.encoder.architecture": "deeplabv2",
            "gen.s.architecture": "deeplabv2",
        },
        {
            "__doc": "MSDP no End-to-end",  # 6
            "domains": ["rf", "r", "s"],
            "tasks": ["m", "s", "d", "p"],
        },
        {
            "__doc": "MSDP inference only no exp",  # 7
            "__inference": True,
            "__use_comet": False,
            "domains": ["rf", "r", "s"],
            "tasks": ["m", "s", "d", "p"],
        },
        {
            "__doc": "MSDP with End-to-end",  # 8
            "__pl4m": True,
            "domains": ["rf", "r", "s"],
            "tasks": ["m", "s", "d", "p"],
        },
        {
            "__doc": "Kitti pretrain",  # 9
            "train.epochs": 2,
            "train.kitti.pretrain": True,
            "train.kitti.epochs": 1,
            "domains": ["kitti", "r", "s"],
            "train.kitti.batch_size": 2,
        },
        {"__doc": "Depth Dada archi", "gen.d.architecture": "dada"},  # 10
        {
            "__doc": "Depth Base archi",  # 11
            "gen.d.architecture": "base",
            "gen.m.use_dada": False,
            "gen.s.use_dada": False,
        },
        {
            "__doc": "Depth Base Classification",  # 12
            "gen.d.architecture": "base",
            "gen.d.classify.enable": True,
            "gen.m.use_dada": False,
            "gen.s.use_dada": False,
        },
        {
            "__doc": "MSD Resnet V3+ backbone",  # 13
            "gen.deeplabv3.backbone": "resnet",
        },
        {
            "__use_comet": False,
            "__doc": "MSD SPADE 12 (without x)",  # 14
            "__verbose": 1,
            "gen.m.use_spade": True,
            "gen.m.spade.cond_nc": 12,
        },
        {
            "__use_comet": False,
            "__doc": "MSD SPADE 15 (with x)",  # 15
            "__verbose": 1,
            "gen.m.use_spade": True,
            "gen.m.spade.cond_nc": 15,
        },
        {
            "__use_comet": False,
            "__doc": "Painter With Diff Augment",  # 16
            "__verbose": 1,
            "domains": ["rf"],
            "tasks": ["p"],
            "gen.p.diff_aug.use": True,
        },
        {
            "__use_comet": False,
            "__doc": "MSD DADA_s",  # 17
            "__verbose": 1,
            "gen.s.use_dada": True,
            "gen.m.use_dada": False,
        },
        {
            "__use_comet": False,
            "__doc": "MSD DADA_ms",  # 18
            "__verbose": 1,
            "gen.s.use_dada": True,
            "gen.m.use_dada": True,
        },
    ]

    n_confs = len(test_scenarios)

    fails = []
    successes = []

    # --------------------------------
    # -----  Run Test Scenarios  -----
    # --------------------------------

    for test_idx, conf in enumerate(test_scenarios):
        if test_idx in exclude or (include and test_idx not in include):
            reason = (
                "because it is in exclude"
                if test_idx in exclude
                else "because it is not in include"
            )
            print("Ignoring test", test_idx, reason)
            continue

        # copy base scenario opts
        test_opts = deepcopy(base_opts)
        # update with scenario configuration
        set_conf(test_opts, conf)

        # print scenario description
        print_start(
            f"[{test_idx}/{n_confs - 1}] "
            + conf.get("__doc", "WARNING: no __doc for test scenario")
        )
        print()

        comet = conf.get("__use_comet", True)
        pl4m = conf.get("__pl4m", False)
        inference = conf.get("__inference", False)
        verbose = conf.get("__verbose", 0)

        # log to the shared experiment only when the scenario asks for it
        test_exp = global_exp if comet else None

        # Set before `try`: if a BaseException (e.g. KeyboardInterrupt)
        # escapes `except Exception`, `finally` would otherwise read an
        # unbound name (NameError on the first scenario) or a stale outcome
        # left over from the previous scenario.
        ok = False
        try:
            # create trainer
            trainer = Trainer(
                opts=test_opts,
                verbose=verbose,
                comet_exp=test_exp,
            )
            trainer.functional_test_mode()

            # set (or not) painter loss for masker (= end-to-end)
            if pl4m:
                trainer.use_pl4m = True

            # test training procedure
            trainer.setup(inference=inference)
            if not inference:
                trainer.train()

            successes.append(test_idx)
            ok = True
        except Exception as e:
            print(e)
            print(traceback.format_exc())
            fails.append(test_idx)
        finally:
            print_end(ok=ok)

    print_end(desc=" ----- Summary ----- ")
    # count only scenarios that actually ran, not those skipped by
    # --include / --exclude
    n_run = len(successes) + len(fails)
    if n_run == 0:
        print("•• No scenario was run")
    elif not fails:
        print("•• All scenarios were successful")
    else:
        print(f"•• {len(successes)}/{n_run} successful tests")
        print(f"•• Failed test indices: {', '.join(map(str, fails))}")