from typing import TYPE_CHECKING, Callable, Union
from easydict import EasyDict
import treetensor.torch as ttorch
from ditk import logging
import numpy as np
from ding.policy import Policy
from ding.framework import task, OfflineRLContext, OnlineRLContext


def trainer(cfg: EasyDict, policy: Policy, log_freq: int = 100) -> Callable:
    """
    Overview:
        The middleware that executes a single training process.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - policy (:obj:`Policy`): The policy to be trained in step-by-step mode.
        - log_freq (:obj:`int`): The frequency (iteration) of showing log.
    Returns:
        - middleware (:obj:`Callable`): The ``_train`` closure, or ``task.void()`` when the task router \
            is active and this process does not have the LEARNER role.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        """
        Input of ctx:
            - train_data (:obj:`Dict`): The data used to update the network. It will train only if \
                the data is not empty.
            - train_iter: (:obj:`int`): The training iteration count. The log will be printed once \
                it reaches certain values.
        Output of ctx:
            - train_output (:obj:`Dict`): The training output in the Dict format, including loss info.
        Raises:
            - TypeError: On a logging iteration, if ``ctx`` is neither an ``OnlineRLContext`` nor an \
                ``OfflineRLContext``.
        """
        if ctx.train_data is None:
            return

        train_output = policy.forward(ctx.train_data)

        if ctx.train_iter % log_freq == 0:
            # ``policy.forward`` may return either one dict or a list of dicts;
            # a list is reduced to the mean of its 'total_loss' entries.
            if isinstance(train_output, list):
                train_output_loss = np.mean([item['total_loss'] for item in train_output])
            else:
                train_output_loss = train_output['total_loss']

            if isinstance(ctx, OnlineRLContext):
                logging.info(
                    'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(
                        ctx.train_iter, ctx.env_step, train_output_loss
                    )
                )
            elif isinstance(ctx, OfflineRLContext):
                logging.info('Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, train_output_loss))
            else:
                raise TypeError("not supported ctx type: {}".format(type(ctx)))

        ctx.train_iter += 1
        ctx.train_output = train_output

    return _train


def multistep_trainer(policy: Policy, log_freq: int = 100) -> Callable:
    """
    Overview:
        The middleware that executes training for a target num of steps.
    Arguments:
        - policy (:obj:`Policy`): The policy specialized for multi-step training.
        - log_freq (:obj:`int`): The frequency (iteration) of showing log.
    Returns:
        - middleware (:obj:`Callable`): The ``_train`` closure, or ``task.void()`` when the task router \
            is active and this process does not have the LEARNER role.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    # Iteration at which the last log line was emitted; -1 means "never logged".
    last_log_iter = -1

    def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        """
        Input of ctx:
            - train_data: The data used to update the network. It will train only if the data is not empty.
            - train_iter: (:obj:`int`): The training iteration count. The log will be printed if \
                it reaches certain values.
        Output of ctx:
            - train_output (:obj:`List[Dict]`): The training output listed by steps.
        Raises:
            - AttributeError: If ``policy`` has neither a ``_device`` attribute nor a ``get_attribute`` method.
        """
        nonlocal last_log_iter

        if ctx.train_data is None:  # not enough data from data fetcher
            return

        if hasattr(policy, "_device"):  # For ppof policy
            data = ctx.train_data.to(policy._device)
        elif hasattr(policy, "get_attribute"):  # For other policy
            data = ctx.train_data.to(policy.get_attribute("device"))
        else:
            # Must be ``raise``: ``assert <exception instance>`` is always truthy and
            # would fall through to ``policy.forward(data)`` with ``data`` unbound.
            raise AttributeError("Policy should have attribution '_device'.")

        train_output = policy.forward(data)

        # One call advances train_iter by len(train_output), so logging is keyed
        # on the distance since the last log rather than on an exact modulo hit.
        if ctx.train_iter - last_log_iter >= log_freq:
            loss = np.mean([o['total_loss'] for o in train_output])
            if isinstance(ctx, OfflineRLContext):
                logging.info('Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, loss))
            else:
                logging.info(
                    'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(ctx.train_iter, ctx.env_step, loss)
                )
            last_log_iter = ctx.train_iter

        ctx.train_iter += len(train_output)
        ctx.train_output = train_output

    return _train


# TODO reward model