import json
import logging
import time
from dataclasses import KW_ONLY, dataclass
from pathlib import Path
from typing import Protocol

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from .control import non_blocking_input
from .distributed import is_global_leader
from .engine import Engine
from .utils import tree_map

logger = logging.getLogger(__name__)


class EvalFn(Protocol):
    def __call__(self, engine: Engine, eval_dir: Path) -> None:
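        """Run evaluation for the model held by `engine`, presumably writing its outputs under `eval_dir`."""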
        ...


class EngineLoader(Protocol):
    def __call__(self, run_dir: Path) -> Engine:
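        """Construct an Engine for the run in `run_dir` (typically restoring any existing checkpoint there)."""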
        ...


class GenFeeder(Protocol):
    def __call__(self, engine: Engine, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]:
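        """Run the generator forward pass on `batch`; return the generated (fake) tensor and a dict of named losses."""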
        ...


class DisFeeder(Protocol):
    def __call__(self, engine: Engine, batch: dict[str, Tensor] | None, fake: Tensor) -> dict[str, Tensor]:
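        """Compute discriminator-side losses.

        The training loop calls this twice per GAN step: with batch=None and a
        non-detached `fake` for the generator's adversarial loss, then with the
        real `batch` and a detached `fake` for the discriminator update.
        """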
        ...


@dataclass
class TrainLoop:
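    """Resumable training loop for a generator Engine with an optional GAN discriminator Engine.

    A minimal usage sketch (assuming user-provided `load_generator`/`feed_generator`
    callbacks and a DataLoader `dl`; these names are illustrative, not part of this module):

        loop = TrainLoop(
            run_dir=Path("runs/example"),
            train_dl=dl,
            load_G=load_generator,
            feed_G=feed_generator,
        )
        loop.run(max_steps=100_000)
    """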
    _ = KW_ONLY

    run_dir: Path
    train_dl: DataLoader

    load_G: EngineLoader
    feed_G: GenFeeder
    load_D: EngineLoader | None = None
    feed_D: DisFeeder | None = None

    update_every: int = 5_000
    eval_every: int = 5_000
    backup_steps: tuple[int, ...] = (5_000, 100_000, 500_000)

    device: str = "cuda"
    eval_fn: EvalFn | None = None
    gan_training_start_step: int | None = None

    @property
    def global_step(self):
        return self.engine_G.global_step

    @property
    def eval_dir(self) -> Path | None:
        if self.eval_every != 0:
            eval_dir = self.run_dir.joinpath("eval")
            eval_dir.mkdir(exist_ok=True)
        else:
            eval_dir = None
        return eval_dir

    @property
    def viz_dir(self) -> Path:
        return self.run_dir / "viz"

    def make_current_step_viz_path(self, name: str, suffix: str) -> Path:
        path = (self.viz_dir / name / f"{self.global_step}").with_suffix(suffix)
        path.parent.mkdir(exist_ok=True, parents=True)
        return path

    def __post_init__(self):
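        """Load the generator engine, and the discriminator engine if a loader for it was provided."""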
        engine_G = self.load_G(self.run_dir)
        if self.load_D is None:
            engine_D = None
        else:
            engine_D = self.load_D(self.run_dir)
        self.engine_G = engine_G
        self.engine_D = engine_D

    @property
    def model_G(self):
        return self.engine_G.module

    @property
    def model_D(self):
        if self.engine_D is None:
            return None
        return self.engine_D.module

    def save_checkpoint(self, tag="default"):
        engine_G = self.engine_G
        engine_D = self.engine_D
        engine_G.save_checkpoint(tag=tag)
        if engine_D is not None:
            engine_D.save_checkpoint(tag=tag)

    def run(self, max_steps: int = -1):
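        """Train until `max_steps` is reached or a "quit" command is received.

        The default max_steps=-1 never matches the step counter, so the loop
        runs until it is stopped interactively.
        """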
        self.set_running_loop_(self)

        train_dl = self.train_dl
        update_every = self.update_every
        eval_every = self.eval_every
        device = self.device
        eval_fn = self.eval_fn

        engine_G = self.engine_G
        engine_D = self.engine_D
        eval_dir = self.eval_dir

        init_step = self.global_step

        logger.info(f"\nTraining from step {init_step} to step {max_steps}")
        warmup_steps = {init_step + x for x in [50, 100, 500]}

        engine_G.train()

        if engine_D is not None:
            engine_D.train()

        gan_start_step = self.gan_training_start_step

        while True:
            loss_G = loss_D = 0

            for batch in train_dl:
                torch.cuda.synchronize()
                start_time = time.time()

                step = self.global_step + 1

                batch = tree_map(lambda x: x.to(device) if isinstance(x, Tensor) else x, batch)

                stats = {"step": step}
                gan_started = gan_start_step is not None and (step >= gan_start_step or step == 1)
                gan_started &= engine_D is not None

                fake, losses = self.feed_G(engine=engine_G, batch=batch)
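
                # When the GAN phase is active, add an adversarial term to the generator
                # losses, computed against a frozen discriminator (batch=None, `fake` kept attached).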
                if gan_started:
                    assert engine_D is not None
                    assert self.feed_D is not None

                    engine_D.freeze_()
                    losses |= self.feed_D(engine=engine_D, batch=None, fake=fake)

                loss_G = sum(losses.values())
                stats |= {f"G/{k}": v.item() for k, v in losses.items()}
                stats |= {f"G/{k}": v for k, v in engine_G.gather_attribute("stats").items()}
                del losses

                assert isinstance(loss_G, Tensor)
                stats["G/loss"] = loss_G.item()
                stats["G/lr"] = engine_G.get_lr()[0]
                stats["G/grad_norm"] = engine_G.get_grad_norm() or 0

                if loss_G.isnan().item():
                    logger.error("Generator loss is NaN, skipping step")
                    continue

                engine_G.backward(loss_G)
                engine_G.step()
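
                # Discriminator update: unfreeze D and train it on the real batch
                # against the detached `fake` from the generator pass above.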
                if gan_started:
                    assert engine_D is not None
                    assert self.feed_D is not None

                    engine_D.unfreeze_()
                    losses = self.feed_D(engine=engine_D, batch=batch, fake=fake.detach())
                    del fake

                    assert isinstance(losses, dict)
                    loss_D = sum(losses.values())
                    assert isinstance(loss_D, Tensor)

                    stats |= {f"D/{k}": v.item() for k, v in losses.items()}
                    stats |= {f"D/{k}": v for k, v in engine_D.gather_attribute("stats").items()}
                    del losses

                    if loss_D.isnan().item():
                        logger.error("Discriminator loss is NaN, skipping step")
                        continue

                    engine_D.backward(loss_D)
                    engine_D.step()

                    stats["D/loss"] = loss_D.item()
                    stats["D/lr"] = engine_D.get_lr()[0]
                    stats["D/grad_norm"] = engine_D.get_grad_norm() or 0

                torch.cuda.synchronize()
                stats["elapsed_time"] = time.time() - start_time
                stats = tree_map(lambda x: float(f"{x:.4g}") if isinstance(x, float) else x, stats)
                logger.info(json.dumps(stats, indent=0))
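
                # Interactive control: check for a command without blocking;
                # recognized commands are "eval", "quit", "backup", and "save".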
                command = non_blocking_input().strip()

                evaling = (eval_every != 0 and step % eval_every == 0) or step in warmup_steps or command == "eval"
                if eval_fn is not None and is_global_leader() and eval_dir is not None and evaling:
                    engine_G.eval()
                    eval_fn(engine_G, eval_dir=eval_dir)
                    engine_G.train()

                if command == "quit":
                    logger.info("Training paused")
                    self.save_checkpoint("default")
                    return

                if command == "backup" or step in self.backup_steps:
                    logger.info("Backing up")
                    self.save_checkpoint(tag=f"backup_{step:07d}")

                if step % update_every == 0 or command == "save":
                    self.save_checkpoint(tag="default")

                if step == max_steps:
                    logger.info("Training finished")
                    self.save_checkpoint(tag="default")
                    return

    @classmethod
    def set_running_loop_(cls, loop):
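        """Register `loop` as the class-wide "currently running" loop used by the getters below."""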
        assert isinstance(loop, cls), f"Expected {cls}, got {type(loop)}"
        cls._running_loop = loop

    @classmethod
    def get_running_loop(cls) -> "TrainLoop | None":
        if hasattr(cls, "_running_loop"):
            assert isinstance(cls._running_loop, cls)
            return cls._running_loop
        return None

    @classmethod
    def get_running_loop_global_step(cls) -> int | None:
        if loop := cls.get_running_loop():
            return loop.global_step
        return None

    @classmethod
    def get_running_loop_viz_path(cls, name: str, suffix: str) -> Path | None:
        if loop := cls.get_running_loop():
            return loop.make_current_step_viz_path(name, suffix)
        return None