zjowowen's picture
init space
079c32c
raw
history blame
4.34 kB
from typing import TYPE_CHECKING, Callable, Union
from easydict import EasyDict
import treetensor.torch as ttorch
from ditk import logging
import numpy as np
from ding.policy import Policy
from ding.framework import task, OfflineRLContext, OnlineRLContext
def trainer(cfg: EasyDict, policy: Policy, log_freq: int = 100) -> Callable:
    """
    Overview:
        Build a middleware that performs exactly one ``policy.forward`` training step per invocation.
    Arguments:
        - cfg (:obj:`EasyDict`): Config. It is accepted but not read by this middleware.
        - policy (:obj:`Policy`): The policy to be trained in step-by-step mode.
        - log_freq (:obj:`int`): A loss log line is emitted whenever ``ctx.train_iter`` is a multiple of this value.
    Returns:
        - middleware (:obj:`Callable`): The ``_train`` middleware, or ``task.void()`` when the task router is \
            active and this process does not hold the learner role.
    """
    skip_on_this_node = task.router.is_active and not task.has_role(task.role.LEARNER)
    if skip_on_this_node:
        return task.void()

    def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        """
        Input of ctx:
            - train_data (:obj:`Dict`): The data used to update the network. The whole step is skipped \
                when it is ``None``.
            - train_iter (:obj:`int`): The training iteration count. A loss log is printed when it is a \
                multiple of ``log_freq``.
        Output of ctx:
            - train_output (:obj:`Dict`): The output returned by ``policy.forward``, including loss info.
            - train_iter (:obj:`int`): Incremented by one after every executed step.
        Raises:
            - TypeError: On a logging iteration, if ``ctx`` is neither an ``OnlineRLContext`` nor an \
                ``OfflineRLContext``.
        """
        batch = ctx.train_data
        if batch is None:
            return
        output = policy.forward(batch)

        if ctx.train_iter % log_freq == 0:
            # ``policy.forward`` may hand back a single output dict or a list of them; a list is averaged.
            loss = (
                np.mean([item['total_loss'] for item in output])
                if isinstance(output, list) else output['total_loss']
            )
            if isinstance(ctx, OnlineRLContext):
                message = 'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(
                    ctx.train_iter, ctx.env_step, loss
                )
            elif isinstance(ctx, OfflineRLContext):
                message = 'Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, loss)
            else:
                raise TypeError("not supported ctx type: {}".format(type(ctx)))
            logging.info(message)

        ctx.train_iter += 1
        ctx.train_output = output

    return _train
def multistep_trainer(policy: Policy, log_freq: int = 100) -> Callable:
    """
    Overview:
        The middleware that executes training for a target num of steps.
    Arguments:
        - policy (:obj:`Policy`): The policy specialized for multi-step training. Its ``forward`` returns a \
            list with one output dict (holding ``'total_loss'``) per training step.
        - log_freq (:obj:`int`): The minimum number of training iterations between two log lines.
    Returns:
        - middleware (:obj:`Callable`): The ``_train`` middleware, or ``task.void()`` when the task router is \
            active and this process does not hold the learner role.
    Raises:
        - AttributeError: (from the returned middleware) if ``policy`` has neither a ``_device`` attribute nor \
            a ``get_attribute`` method, so the device to move the training data to cannot be determined.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    # Iteration at which the previous log line was printed. Starting at -1 makes the first log fire once
    # ``ctx.train_iter`` reaches ``log_freq - 1``.
    last_log_iter = -1

    def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        """
        Input of ctx:
            - train_data: The data used to update the network. Nothing is done when it is ``None``.
            - train_iter: (:obj:`int`): The training iteration count. The log is printed once at least \
                ``log_freq`` iterations have passed since the previous log.
        Output of ctx:
            - train_output (:obj:`List[Dict]`): The training output listed by steps.
            - train_iter (:obj:`int`): Advanced by the number of steps in ``train_output``.
        """
        nonlocal last_log_iter
        if ctx.train_data is None:  # not enough data from the data fetcher yet
            return
        if hasattr(policy, "_device"):  # For ppof policy
            data = ctx.train_data.to(policy._device)
        elif hasattr(policy, "get_attribute"):  # For other policy
            data = ctx.train_data.to(policy.get_attribute("device"))
        else:
            # Raise explicitly (not ``assert``): an exception instance is truthy, so asserting it never fires
            # and ``data`` would be left unbound, surfacing later as a confusing UnboundLocalError.
            raise AttributeError("Policy should have attribute '_device' or method 'get_attribute'.")
        train_output = policy.forward(data)
        # ``train_iter`` advances by len(train_output) per call and can jump over exact multiples of
        # ``log_freq``, so compare against the last logged iteration instead of using a modulo test.
        if ctx.train_iter - last_log_iter >= log_freq:
            loss = np.mean([o['total_loss'] for o in train_output])
            if isinstance(ctx, OfflineRLContext):
                logging.info('Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, loss))
            else:
                logging.info(
                    'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(ctx.train_iter, ctx.env_step, loss)
                )
            last_log_iter = ctx.train_iter
        ctx.train_iter += len(train_output)
        ctx.train_output = train_output

    return _train
# TODO reward model