from typing import Callable, Union
from easydict import EasyDict
from ditk import logging
import numpy as np
from ding.policy import Policy
from ding.framework import task, OfflineRLContext, OnlineRLContext


def trainer(cfg: EasyDict, policy: Policy, log_freq: int = 100) -> Callable:
    """
    Overview:
        The middleware that executes a single training iteration each time it is called.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - policy (:obj:`Policy`): The policy to be trained in step-by-step mode.
        - log_freq (:obj:`int`): The frequency (in training iterations) of printing logs.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        """
        Input of ctx:
            - train_data (:obj:`Dict`): The data used to update the network. It will train only if \
                the data is not empty.
            - train_iter (:obj:`int`): The training iteration count. A log message is printed \
                once every ``log_freq`` iterations.
        Output of ctx:
            - train_output (:obj:`Dict`): The training output in the Dict format, including loss info.
        """

        if ctx.train_data is None:
            return
        train_output = policy.forward(ctx.train_data)
        if ctx.train_iter % log_freq == 0:
            if isinstance(train_output, list):
                train_output_loss = np.mean([item['total_loss'] for item in train_output])
            else:
                train_output_loss = train_output['total_loss']
            if isinstance(ctx, OnlineRLContext):
                logging.info(
                    'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(
                        ctx.train_iter, ctx.env_step, train_output_loss
                    )
                )
            elif isinstance(ctx, OfflineRLContext):
                logging.info('Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, train_output_loss))
            else:
                raise TypeError("not supported ctx type: {}".format(type(ctx)))
        ctx.train_iter += 1
        ctx.train_output = train_output

    return _train
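

# Usage sketch (illustrative, kept as a comment): ``trainer`` is meant to be composed with other
# middleware inside a task pipeline. ``offline_data_fetcher``, ``dataset`` and the ``max_step``
# value below are assumptions for demonstration, not defined in this module.
#
#   with task.start(ctx=OfflineRLContext()):
#       task.use(offline_data_fetcher(cfg, dataset))   # fetches batches into ctx.train_data
#       task.use(trainer(cfg, policy.learn_mode))      # this middleware: one train iteration per call
#       task.run(max_step=100000)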


def multistep_trainer(policy: Policy, log_freq: int = 100) -> Callable:
    """
    Overview:
        The middleware that executes several training steps in one call, designed for policies whose \
        ``forward`` returns a list of per-step outputs.
    Arguments:
        - policy (:obj:`Policy`): The policy specialized for multi-step training.
        - log_freq (:obj:`int`): The frequency (in training iterations) of printing logs.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    last_log_iter = -1

    def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
        """
        Input of ctx:
            - train_data: The data used to update the network. It will train only if \
                the data is not empty.
            - train_iter (:obj:`int`): The training iteration count. A log message is printed \
                whenever at least ``log_freq`` iterations have elapsed since the last log.
        Output of ctx:
            - train_output (:obj:`List[Dict]`): The training outputs of each step, collected in a list.
        """

        if ctx.train_data is None:  # not enough data from the data fetcher
            return
        if hasattr(policy, "_device"):  # For the ppof policy
            data = ctx.train_data.to(policy._device)
        elif hasattr(policy, "get_attribute"):  # For other policies
            data = ctx.train_data.to(policy.get_attribute("device"))
        else:
            raise AttributeError("Policy should have a '_device' attribute or a 'get_attribute' method.")
        train_output = policy.forward(data)
        nonlocal last_log_iter
        if ctx.train_iter - last_log_iter >= log_freq:
            loss = np.mean([o['total_loss'] for o in train_output])
            if isinstance(ctx, OfflineRLContext):
                logging.info('Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, loss))
            else:
                logging.info(
                    'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(ctx.train_iter, ctx.env_step, loss)
                )
            last_log_iter = ctx.train_iter
        ctx.train_iter += len(train_output)
        ctx.train_output = train_output

    return _train
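

# Usage sketch (illustrative, kept as a comment): ``multistep_trainer`` replaces ``trainer`` when
# ``policy.forward`` runs several optimization steps per call and returns a list of per-step
# outputs (e.g. a PPO-style policy). The collector middleware below is a hypothetical placeholder.
#
#   with task.start(ctx=OnlineRLContext()):
#       task.use(my_collector(cfg, policy, env))            # hypothetical data collection middleware
#       task.use(multistep_trainer(policy, log_freq=50))    # increments train_iter by len(train_output)
#       task.run(max_step=10000)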


# TODO reward model