|
|
|
""" |
|
Train a model on a dataset. |
|
|
|
Usage: |
|
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16 |
|
""" |
|
|
|
import math |
|
import os |
|
import subprocess |
|
import time |
|
import warnings |
|
from copy import deepcopy |
|
from datetime import datetime, timedelta |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
import torch |
|
from torch import distributed as dist |
|
from torch import nn, optim |
|
|
|
from ultralytics.cfg import get_cfg, get_save_dir |
|
from ultralytics.data.utils import check_cls_dataset, check_det_dataset |
|
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights |
|
from ultralytics.utils import ( |
|
DEFAULT_CFG, |
|
LOGGER, |
|
RANK, |
|
TQDM, |
|
__version__, |
|
callbacks, |
|
clean_url, |
|
colorstr, |
|
emojis, |
|
yaml_save, |
|
) |
|
from ultralytics.utils.autobatch import check_train_batch_size |
|
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args |
|
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command |
|
from ultralytics.utils.files import get_latest_run |
|
from ultralytics.utils.torch_utils import ( |
|
EarlyStopping, |
|
ModelEMA, |
|
de_parallel, |
|
init_seeds, |
|
one_cycle, |
|
select_device, |
|
strip_optimizer, |
|
) |
|
|
|
|
|
class BaseTrainer: |
|
""" |
|
BaseTrainer. |
|
|
|
A base class for creating trainers. |
|
|
|
Attributes: |
|
args (SimpleNamespace): Configuration for the trainer. |
|
validator (BaseValidator): Validator instance. |
|
model (nn.Module): Model instance. |
|
callbacks (defaultdict): Dictionary of callbacks. |
|
save_dir (Path): Directory to save results. |
|
wdir (Path): Directory to save weights. |
|
last (Path): Path to the last checkpoint. |
|
best (Path): Path to the best checkpoint. |
|
save_period (int): Save checkpoint every x epochs (disabled if < 1). |
|
batch_size (int): Batch size for training. |
|
epochs (int): Number of epochs to train for. |
|
start_epoch (int): Starting epoch for training. |
|
device (torch.device): Device to use for training. |
|
amp (bool): Flag to enable AMP (Automatic Mixed Precision). |
|
scaler (amp.GradScaler): Gradient scaler for AMP. |
|
data (str): Path to data. |
|
trainset (torch.utils.data.Dataset): Training dataset. |
|
testset (torch.utils.data.Dataset): Testing dataset. |
|
ema (nn.Module): EMA (Exponential Moving Average) of the model. |
|
resume (bool): Resume training from a checkpoint. |
|
lf (Callable): Learning-rate lambda (epoch -> LR multiplier) handed to the LambdaLR scheduler. |
|
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. |
|
best_fitness (float): The best fitness value achieved. |
|
fitness (float): Current fitness value. |
|
loss (float): Current loss value. |
|
tloss (float): Total loss value. |
|
loss_names (list): List of loss names. |
|
csv (Path): Path to results CSV file. |
|
""" |
|
|
|
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initializes the BaseTrainer class.

    Resolves the configuration, handles resume, selects the device, seeds the RNGs, prepares the output
    directories, checks the dataset and registers callbacks.

    Args:
        cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.
        overrides (dict, optional): Configuration overrides. Defaults to None.
        _callbacks (dict, optional): Event-name -> list-of-callables mapping. When falsy, the result of
            `callbacks.get_default_callbacks()` is used instead. Defaults to None.

    Raises:
        RuntimeError: If checking the dataset fails for any reason (the original error is chained).
    """
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)  # may replace self.args with the checkpoint's arguments
    self.device = select_device(self.args.device, self.args.batch)
    self.validator = None
    self.metrics = None
    self.plots = {}
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Output directories
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # keep the run name in sync with the resolved directory
    self.wdir = self.save_dir / "weights"
    if RANK in (-1, 0):
        # Only the single-process run (-1) or DDP rank 0 touches the filesystem
        self.wdir.mkdir(parents=True, exist_ok=True)
        self.args.save_dir = str(self.save_dir)
        yaml_save(self.save_dir / "args.yaml", vars(self.args))
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device: dataloader workers are disabled on CPU and MPS
    if self.device.type in ("cpu", "mps"):
        self.args.workers = 0

    # Model and dataset
    self.model = check_model_file_from_stem(self.args.model)
    try:
        if self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data)
        elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
            "detect",
            "segment",
            "pose",
            "obb",
        ):
            self.data = check_det_dataset(self.args.data)
            if "yaml_file" in self.data:
                # Point args.data at the YAML file reported by the dataset check
                self.args.data = self.data["yaml_file"]
    except Exception as e:
        # Fixed: the emoji in this message was mis-encoded in the source ("β" instead of "❌")
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e

    self.trainset, self.testset = self.get_dataset(self.data)
    self.ema = None

    # Optimization utilities, created later in _setup_train/_setup_scheduler
    self.lf = None
    self.scheduler = None

    # Epoch-level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]  # global batch indices at which training samples are plotted

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    if RANK in (-1, 0):
        callbacks.add_integration_callbacks(self)
|
|
|
def add_callback(self, event: str, callback):
    """Register one more callback for `event`, keeping those already registered."""
    registered = self.callbacks[event]
    registered.append(callback)
|
|
|
def set_callback(self, event: str, callback):
    """Replace every callback registered for `event` with the single given callback."""
    replacement = [callback]
    self.callbacks[event] = replacement
|
|
|
def run_callbacks(self, event: str):
    """Invoke, in registration order, each callback stored for `event`, passing this trainer."""
    handlers = self.callbacks.get(event, [])  # unknown events simply run nothing
    for handler in handlers:
        handler(self)
|
|
|
def train(self):
    """
    Start training, either in this process or through a spawned DDP command.

    The world size is derived from `self.args.device`: the number of comma-separated entries of a
    non-empty string, the length of a tuple/list, otherwise 1 if CUDA is available and 0 if not.
    When the world size exceeds 1 and `LOCAL_RANK` is not set in the environment, a DDP command is
    generated and run as a subprocess; otherwise `_do_train` runs directly.

    Raises:
        subprocess.CalledProcessError: If the spawned DDP command exits with a non-zero status.
    """
    if isinstance(self.args.device, str) and len(self.args.device):
        world_size = len(self.args.device.split(","))
    elif isinstance(self.args.device, (tuple, list)):
        world_size = len(self.args.device)
    elif torch.cuda.is_available():
        world_size = 1
    else:
        world_size = 0

    # LOCAL_RANK being set means we are already inside a DDP worker, so do not spawn again
    if world_size > 1 and "LOCAL_RANK" not in os.environ:
        # Argument checks for options that are incompatible with Multi-GPU training
        if self.args.rect:
            LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch == -1:
            LOGGER.warning(
                "WARNING ⚠️ 'batch=-1' for AutoBatch is incompatible with Multi-GPU training, setting "
                "default 'batch=16'"
            )
            self.args.batch = 16

        # Build and run the DDP command; errors propagate unchanged and cleanup always runs
        cmd, file = generate_ddp_command(world_size, self)
        try:
            LOGGER.info(f'{colorstr("DDP:")} debug command {" ".join(cmd)}')
            subprocess.run(cmd, check=True)
        finally:
            ddp_cleanup(self, str(file))
    else:
        self._do_train(world_size)
|
|
|
def _setup_scheduler(self): |
|
"""Initialize training learning rate scheduler.""" |
|
if self.args.cos_lr: |
|
self.lf = one_cycle(1, self.args.lrf, self.epochs) |
|
else: |
|
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf |
|
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) |
|
|
|
def _setup_ddp(self, world_size):
    """Bind this process to its CUDA device and join the distributed process group."""
    torch.cuda.set_device(RANK)
    self.device = torch.device("cuda", RANK)

    os.environ["NCCL_BLOCKING_WAIT"] = "1"
    backend = "nccl" if dist.is_nccl_available() else "gloo"
    three_hours = timedelta(hours=3)
    dist.init_process_group(backend=backend, timeout=three_hours, rank=RANK, world_size=world_size)
|
|
|
def _setup_train(self, world_size):
    """
    Prepare everything the training loop needs: model, frozen layers, AMP, dataloaders, optimizer, scheduler.

    Args:
        world_size (int): Number of participating processes; values above 1 enable the DDP wrapper
            and the AMP broadcast.
    """
    # Model
    self.run_callbacks("on_pretrain_routine_start")
    ckpt = self.setup_model()
    self.model = self.model.to(self.device)
    self.set_model_attributes()

    # Freeze layers: a list is used as-is, an int n means indices 0..n-1, anything else freezes nothing extra
    if isinstance(self.args.freeze, list):
        freeze_list = self.args.freeze
    elif isinstance(self.args.freeze, int):
        freeze_list = range(self.args.freeze)
    else:
        freeze_list = []
    always_freeze_names = [".dfl"]  # parameters whose name contains this are always frozen
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    for k, v in self.model.named_parameters():
        if any(x in k for x in freeze_layer_names):
            LOGGER.info(f"Freezing layer '{k}'")
            v.requires_grad = False
        elif not v.requires_grad and v.dtype.is_floating_point:
            # Only floating point tensors can require gradients.
            # Fixed: message emoji was mis-encoded, and a WARNING is now logged at warning level.
            LOGGER.warning(
                f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. "
                "See ultralytics.engine.trainer for customization of frozen layers."
            )
            v.requires_grad = True

    # AMP: checked on rank -1/0 only, then broadcast so every DDP rank agrees
    self.amp = torch.tensor(self.args.amp).to(self.device)
    if self.amp and RANK in (-1, 0):
        callbacks_backup = callbacks.default_callbacks.copy()  # restored below, after check_amp
        self.amp = torch.tensor(check_amp(self.model), device=self.device)
        callbacks.default_callbacks = callbacks_backup
    if RANK > -1 and world_size > 1:
        dist.broadcast(self.amp, src=0)
    self.amp = bool(self.amp)
    self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
    if world_size > 1:
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])

    # Image size: grid size is the model's max stride, but never below 32
    gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)
    self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
    self.stride = gs

    # Batch size: -1 requests an automatic estimate, single-process runs only
    if self.batch_size == -1 and RANK == -1:
        self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)

    # Dataloaders: the configured batch size is split across processes
    batch_size = self.batch_size // max(world_size, 1)
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
    if RANK in (-1, 0):
        # Validation uses twice the per-process batch size, except for the "obb" task
        self.test_loader = self.get_dataloader(
            self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
        )
        self.validator = self.get_validator()
        metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
        self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
        self.ema = ModelEMA(self.model)
        if self.args.plots:
            self.plot_training_labels()

    # Optimizer: accumulate gradients towards the nominal batch size and scale weight decay to match
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs
    iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
    self.optimizer = self.build_optimizer(
        model=self.model,
        name=self.args.optimizer,
        lr=self.args.lr0,
        momentum=self.args.momentum,
        decay=weight_decay,
        iterations=iterations,
    )

    # Scheduler, early stopping and resume state
    self._setup_scheduler()
    self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
    self.resume_training(ckpt)
    self.scheduler.last_epoch = self.start_epoch - 1  # align the scheduler with the (possibly resumed) epoch
    self.run_callbacks("on_pretrain_routine_end")
|
|
|
def _do_train(self, world_size=1):
    """
    Run the full training loop.

    Sets up DDP (when world_size > 1) and the training state, then iterates epochs until `self.stop`
    is set: per batch it applies warmup, runs the forward/backward pass and steps the optimizer every
    `self.accumulate` batches; per epoch it validates, writes metrics, saves checkpoints (rank -1/0
    only) and steps the scheduler. After the loop it runs the final evaluation and end-of-training
    callbacks.

    Args:
        world_size (int): Number of participating processes. Defaults to 1.
    """
    if world_size > 1:
        self._setup_ddp(world_size)
    self._setup_train(world_size)

    nb = len(self.train_loader)  # number of batches per epoch
    # Number of warmup iterations (at least 100), or -1 to disable warmup
    nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1
    last_opt_step = -1
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()
    self.run_callbacks("on_train_start")
    LOGGER.info(
        f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
        f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
        f"Logging results to {colorstr('bold', self.save_dir)}\n"
        f'Starting training for ' + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
    )
    if self.args.close_mosaic:
        # Also plot the first three batches after mosaic is closed
        base_idx = (self.epochs - self.args.close_mosaic) * nb
        self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
    epoch = self.start_epoch
    while True:
        self.epoch = epoch
        self.run_callbacks("on_train_epoch_start")
        self.model.train()
        if RANK != -1:
            self.train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(self.train_loader)

        # Switch off mosaic once the last `close_mosaic` epochs begin
        if epoch == (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()
            self.train_loader.reset()

        if RANK in (-1, 0):
            LOGGER.info(self.progress_string())
            pbar = TQDM(enumerate(self.train_loader), total=nb)
        self.tloss = None
        self.optimizer.zero_grad()
        for i, batch in pbar:
            self.run_callbacks("on_train_batch_start")

            # Warmup: interpolate accumulation, lr and momentum over the first nw iterations
            ni = i + nb * epoch  # global batch index since the start of training
            if ni <= nw:
                xi = [0, nw]  # interpolation x-range
                self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # Group 0 starts from warmup_bias_lr, all other groups start from 0.0
                    x["lr"] = np.interp(
                        ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

            # Forward
            with torch.cuda.amp.autocast(self.amp):
                batch = self.preprocess_batch(batch)
                self.loss, self.loss_items = self.model(batch)
                if RANK != -1:
                    self.loss *= world_size
                # Running mean of the loss items over the batches seen so far this epoch
                self.tloss = (
                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                )

            # Backward
            self.scaler.scale(self.loss).backward()

            # Optimize once enough batches have been accumulated
            if ni - last_opt_step >= self.accumulate:
                self.optimizer_step()
                last_opt_step = ni

                # Timed stopping: args.time is a budget in hours
                if self.args.time:
                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                    if RANK != -1:  # DDP: every rank adopts rank 0's decision
                        broadcast_list = [self.stop if RANK == 0 else None]
                        dist.broadcast_object_list(broadcast_list, 0)
                        self.stop = broadcast_list[0]
                    if self.stop:  # training time exceeded
                        break

            # Log
            mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"
            loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
            losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
            if RANK in (-1, 0):
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * (2 + loss_len))
                    % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
                )
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)

            self.run_callbacks("on_train_batch_end")

        self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
        self.run_callbacks("on_train_epoch_end")
        if RANK in (-1, 0):
            final_epoch = epoch + 1 == self.epochs
            self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

            # Validation: every val_period epochs, always within the last 10 epochs, and whenever
            # training is about to end
            if (self.args.val and (((epoch+1) % self.args.val_period == 0) or (self.epochs - epoch) <= 10)) \
                    or final_epoch or self.stopper.possible_stop or self.stop:
                self.metrics, self.fitness = self.validate()
            self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
            self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
            if self.args.time:
                self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

            # Save model
            if self.args.save or final_epoch:
                self.save_model()
                self.run_callbacks("on_model_save")

        # Scheduler
        t = time.time()
        self.epoch_time = t - self.epoch_time_start
        self.epoch_time_start = t
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # silence warnings raised while (re)stepping the scheduler
            if self.args.time:
                # Re-estimate the total epoch count from the mean epoch time and rebuild the scheduler
                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
                self._setup_scheduler()
                self.scheduler.last_epoch = self.epoch  # do not move
                self.stop |= epoch >= self.epochs  # stop if the time budget is exceeded
            self.scheduler.step()
        self.run_callbacks("on_fit_epoch_end")
        torch.cuda.empty_cache()  # clear cached GPU memory at the end of each epoch

        # Early stopping
        if RANK != -1:  # DDP: every rank adopts rank 0's decision
            broadcast_list = [self.stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)
            self.stop = broadcast_list[0]
        if self.stop:
            break  # must break on all DDP ranks
        epoch += 1

    if RANK in (-1, 0):
        # Final evaluation with the best checkpoint
        LOGGER.info(
            f"\n{epoch - self.start_epoch + 1} epochs completed in "
            f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
        )
        self.final_eval()
        if self.args.plots:
            self.plot_metrics()
        self.run_callbacks("on_train_end")
    torch.cuda.empty_cache()
    self.run_callbacks("teardown")
|
|
|
def save_model(self):
    """Write the current training state to last.pt, and to best.pt / epochN.pt when applicable."""
    import pandas as pd  # local import, as in the original implementation

    # Current metrics plus fitness
    metrics = dict(self.metrics)
    metrics["fitness"] = self.fitness

    # Full results history from the CSV, with whitespace-padded column names stripped
    history = pd.read_csv(self.csv).to_dict(orient="list")
    results = {column.strip(): values for column, values in history.items()}

    ckpt = {
        "epoch": self.epoch,
        "best_fitness": self.best_fitness,
        "model": deepcopy(de_parallel(self.model)).half(),
        "ema": deepcopy(self.ema.ema).half(),
        "updates": self.ema.updates,
        "optimizer": self.optimizer.state_dict(),
        "train_args": vars(self.args),
        "train_metrics": metrics,
        "train_results": results,
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # last.pt is always written
    torch.save(ckpt, self.last)

    # best.pt is written when the current fitness equals the best seen so far
    is_best = self.best_fitness == self.fitness
    if is_best:
        torch.save(ckpt, self.best)

    # Periodic snapshot every save_period epochs (never for epoch 0, disabled when save_period < 1)
    periodic = (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0)
    if periodic:
        torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
|
|
|
@staticmethod |
|
def get_dataset(data): |
|
""" |
|
Get train, val path from data dict if it exists. |
|
|
|
Returns None if data format is not recognized. |
|
""" |
|
return data["train"], data.get("val") or data.get("test") |
|
|
|
def setup_model(self):
    """
    Turn `self.model` into a model instance unless it already is an nn.Module.

    Returns:
        The checkpoint loaded from a '.pt' file, or None when no checkpoint was loaded.
    """
    if isinstance(self.model, torch.nn.Module):
        return None  # already a model, nothing to set up

    source = self.model
    if str(source).endswith(".pt"):
        # Weights file: take the configuration from the checkpoint's model
        weights, ckpt = attempt_load_one_weight(source)
        cfg = ckpt["model"].yaml
    else:
        weights, ckpt, cfg = None, None, source
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)
    return ckpt
|
|
|
def optimizer_step(self):
    """Unscale gradients, clip them to a max norm of 10, step the optimizer and update the EMA if present."""
    scaler, optimizer = self.scaler, self.optimizer
    scaler.unscale_(optimizer)  # gradients must be unscaled before clipping
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    if not self.ema:
        return
    self.ema.update(self.model)
|
|
|
def preprocess_batch(self, batch):
    """Hook for task-specific batch preprocessing; this base version hands the batch back untouched."""
    return batch
|
|
|
def validate(self):
    """
    Run `self.validator` on this trainer and track the best fitness.

    Returns:
        (tuple): The metrics dict with "fitness" removed, and the fitness value. When the validator
            reports no "fitness", the negated current loss is used instead.
    """
    metrics = self.validator(self)
    fallback = -self.loss.detach().cpu().numpy()  # evaluated unconditionally, as the pop() default
    fitness = metrics.pop("fitness", fallback)
    improved = not self.best_fitness or self.best_fitness < fitness
    if improved:
        self.best_fitness = fitness
    return metrics, fitness
|
|
|
def get_model(self, cfg=None, weights=None, verbose=True):
    """Build the task model; the base trainer has no implementation and always raises NotImplementedError."""
    message = "This task trainer doesn't support loading cfg files"
    raise NotImplementedError(message)
|
|
|
def get_validator(self):
    """Create the validator; the base trainer has no implementation and always raises NotImplementedError."""
    message = "get_validator function not implemented in trainer"
    raise NotImplementedError(message)
|
|
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Create a dataloader; the base trainer has no implementation and always raises NotImplementedError."""
    message = "get_dataloader function not implemented in trainer"
    raise NotImplementedError(message)
|
|
|
def build_dataset(self, img_path, mode="train", batch=None):
    """Create a dataset; the base trainer has no implementation and always raises NotImplementedError."""
    message = "build_dataset function not implemented in trainer"
    raise NotImplementedError(message)
|
|
|
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Label loss items, or list the loss names when no items are given.

    Returns:
        (dict | list): `{"loss": loss_items}` when `loss_items` is provided, otherwise `["loss"]`.
            `prefix` is not used by this base implementation.
    """
    if loss_items is None:
        return ["loss"]
    return {"loss": loss_items}
|
|
|
def set_model_attributes(self):
    """Copy the dataset's class names onto the model before training."""
    class_names = self.data["names"]
    self.model.names = class_names
|
|
|
def build_targets(self, preds, targets):
    """Hook for building training targets; the base version does nothing and returns None."""
    return None
|
|
|
def progress_string(self):
    """Return the progress header logged at the start of each epoch; empty in the base trainer."""
    header = ""
    return header
|
|
|
|
|
def plot_training_samples(self, batch, ni):
    """Hook for plotting a training batch at global batch index `ni`; the base version does nothing."""
    return None
|
|
|
def plot_training_labels(self):
    """Hook for plotting the training labels; the base version does nothing."""
    return None
|
|
|
def save_metrics(self, metrics):
    """Append one row of metrics for the current epoch to the results CSV, writing a header row first if the file is new."""
    columns = ["epoch"] + list(metrics.keys())
    width = len(metrics) + 1  # metric columns plus the leading epoch column
    if self.csv.exists():
        header = ""
    else:
        header = ("%23s," * width % tuple(columns)).rstrip(",") + "\n"
    with open(self.csv, "a") as f:
        row_values = [self.epoch + 1] + list(metrics.values())  # epochs are written 1-based
        row = ("%23.5g," * width % tuple(row_values)).rstrip(",") + "\n"
        f.write(header + row)
|
|
|
def plot_metrics(self):
    """Hook for plotting training metrics; the base version does nothing."""
    return None
|
|
|
def on_plot(self, name, data=None):
    """Record a plot under its path, together with optional data and the registration time."""
    key = Path(name)
    record = {"data": data, "timestamp": time.time()}
    self.plots[key] = record
|
|
|
def final_eval(self):
    """Strip the optimizer from the existing last/best checkpoints and validate the best one."""
    for ckpt_path in (self.last, self.best):
        if not ckpt_path.exists():
            continue
        strip_optimizer(ckpt_path)
        if ckpt_path is not self.best:
            continue
        # Only the best checkpoint is validated
        LOGGER.info(f"\nValidating {ckpt_path}...")
        self.validator.args.plots = self.args.plots
        self.metrics = self.validator(model=ckpt_path)
        self.metrics.pop("fitness", None)
        self.run_callbacks("on_fit_epoch_end")
|
|
|
def check_resume(self, overrides):
    """
    Resolve the resume checkpoint, if resuming was requested, and reload the arguments stored in it.

    When `self.args.resume` is truthy, the checkpoint is taken from that path if it exists, otherwise
    from `get_latest_run()`. `self.args` is rebuilt from the checkpoint's arguments, and any
    'imgsz', 'batch' or 'device' present in `overrides` is re-applied on top.

    Args:
        overrides (dict | None): Configuration overrides given to the trainer. None is treated as empty.

    Raises:
        FileNotFoundError: If anything fails while locating or loading the checkpoint.
    """
    # Fixed: overrides defaults to None in __init__, which made `k in overrides` raise a TypeError
    # that was then misreported as a missing checkpoint.
    overrides = overrides or {}
    resume = self.args.resume
    if resume:
        try:
            exists = isinstance(resume, (str, Path)) and Path(resume).exists()
            last = Path(check_file(resume) if exists else get_latest_run())

            # If the dataset path stored in the checkpoint is gone, fall back to the current one
            ckpt_args = attempt_load_weights(last).args
            if not Path(ckpt_args["data"]).exists():
                ckpt_args["data"] = self.args.data

            resume = True
            self.args = get_cfg(ckpt_args)
            self.args.model = self.args.resume = str(last)  # point both at the resolved checkpoint
            for k in "imgsz", "batch", "device":
                # These arguments may be changed when resuming
                if k in overrides:
                    setattr(self.args, k, overrides[k])

        except Exception as e:
            raise FileNotFoundError(
                "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
                "i.e. 'yolo train resume model=path/to/last.pt'"
            ) from e
    self.resume = resume
|
|
|
def resume_training(self, ckpt):
    """
    Restore optimizer, EMA, best fitness and start epoch from a checkpoint when resuming.

    Does nothing when `ckpt` is None or `self.resume` is falsy.

    Args:
        ckpt (dict | None): Loaded checkpoint; reads the keys "epoch", "optimizer", "best_fitness",
            "ema" and "updates".

    Raises:
        AssertionError: If the checkpoint's epoch leaves nothing to resume (start epoch not above 0).
    """
    if ckpt is None or not self.resume:
        return
    best_fitness = 0.0
    start_epoch = ckpt["epoch"] + 1
    if ckpt["optimizer"] is not None:
        self.optimizer.load_state_dict(ckpt["optimizer"])
        best_fitness = ckpt["best_fitness"]
    if self.ema and ckpt.get("ema"):
        self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())
        self.ema.updates = ckpt["updates"]
    assert start_epoch > 0, (
        f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
        f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
    )
    LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
    if self.epochs < start_epoch:
        # Fixed: log the model path (self.args.model) like the messages above, not self.model
        LOGGER.info(
            f"{self.args.model} has been trained for {ckpt['epoch']} epochs. "
            f"Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # train for the requested number of additional epochs
    self.best_fitness = best_fitness
    self.start_epoch = start_epoch
    if start_epoch > (self.epochs - self.args.close_mosaic):
        # Resuming inside the final close_mosaic epochs: mosaic must already be off
        self._close_dataloader_mosaic()
|
|
|
def _close_dataloader_mosaic(self): |
|
"""Update dataloaders to stop using mosaic augmentation.""" |
|
if hasattr(self.train_loader.dataset, "mosaic"): |
|
self.train_loader.dataset.mosaic = False |
|
if hasattr(self.train_loader.dataset, "close_mosaic"): |
|
LOGGER.info("Closing dataloader mosaic") |
|
self.train_loader.dataset.close_mosaic(hyp=self.args) |
|
|
|
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
    weight decay, and number of iterations.

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
            based on the number of iterations. Default: 'auto'.
        lr (float, optional): The learning rate for the optimizer. Default: 0.001.
        momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
        decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'. Default: 1e5.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.

    Raises:
        NotImplementedError: If `name` is not one of the supported optimizers.
    """
    # Parameter groups: g[0] weights with decay, g[1] normalization-layer weights, g[2] biases
    g = [], [], []
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # every torch.nn entry with "Norm" in its name
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = getattr(model, "nc", 10)  # number of classes, 10 when the model does not define it
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # AdamW learning rate, scaled by class count
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # so that warmup starts every group from 0.0

    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn):  # normalization weight (no decay)
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    # The optimizer is created on the bias group; the other two groups are added below
    if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        # Fixed: the list now includes Adamax (accepted above), and the two sentences are separated
        raise NotImplementedError(
            f"Optimizer '{name}' not found in list of available optimizers "
            f"[Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, SGD, auto]. "
            "To request support for additional optimizers please visit https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # weights, with decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # normalization weights, no decay
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)'
    )
    return optimizer
|
|