|
from typing import TYPE_CHECKING, Callable, List, Tuple, Union, Dict, Optional |
|
from easydict import EasyDict |
|
|
|
|
from ding.framework import task |
|
from ding.data import Buffer |
|
from .functional import trainer, offpolicy_data_fetcher, reward_estimator, her_data_enhancer |
|
|
|
if TYPE_CHECKING: |
|
from ding.framework import Context, OnlineRLContext |
|
from ding.policy import Policy |
|
from ding.reward_model import BaseRewardModel |
|
|
|
|
|
class OffPolicyLearner: |
|
""" |
|
Overview: |
|
The class of the off-policy learner, including data fetching and model training. Use \ |
|
the `__call__` method to execute the whole learning process. |
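    Examples:
        A minimal usage sketch (an illustration, not code from this module): ``cfg``, ``policy`` \
        and ``buffer_`` are placeholders assumed to be built elsewhere, and middlewares that \
        collect data and push it into the buffer are assumed to run earlier in the pipeline.

        .. code-block:: python

            from ding.framework import task, OnlineRLContext

            with task.start(ctx=OnlineRLContext()):
                # ... collector / data-pusher middlewares are registered here ...
                task.use(OffPolicyLearner(cfg, policy, buffer_))
                task.run()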
|
""" |
|
|
|
def __new__(cls, *args, **kwargs): |
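        # When the parallel router is active, only processes that own the LEARNER role
        # should build this middleware; every other process receives a no-op placeholder
        # from ``task.void()`` so the same pipeline code can run on all roles.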
|
if task.router.is_active and not task.has_role(task.role.LEARNER): |
|
return task.void() |
|
return super(OffPolicyLearner, cls).__new__(cls) |
|
|
|
def __init__( |
|
self, |
|
cfg: EasyDict, |
|
policy: 'Policy', |
|
buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]], |
|
reward_model: Optional['BaseRewardModel'] = None, |
|
log_freq: int = 100, |
|
) -> None: |
|
""" |
|
Arguments: |
|
- cfg (:obj:`EasyDict`): Config. |
|
- policy (:obj:`Policy`): The policy to be trained. |
|
            - buffer\_ (:obj:`Buffer`): The replay buffer that stores the data for training. A single \
                buffer, a list of ``(buffer, float)`` pairs or a dict of buffers is accepted.
|
            - reward_model (:obj:`BaseRewardModel`): An optional extra reward estimator, such as RND or ICM. \
                Defaults to None.
|
            - log_freq (:obj:`int`): How often (in training iterations) logs are printed. Defaults to 100.
|
""" |
|
self.cfg = cfg |
|
self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_)) |
|
self._trainer = task.wrap(trainer(cfg, policy, log_freq=log_freq)) |
|
if reward_model is not None: |
|
self._reward_estimator = task.wrap(reward_estimator(cfg, reward_model)) |
|
else: |
|
self._reward_estimator = None |
|
|
|
def __call__(self, ctx: "OnlineRLContext") -> None: |
|
""" |
|
Output of ctx: |
|
            - train_output (:obj:`List`): The list of training outputs, one entry per completed update.
|
""" |
|
train_output_queue = [] |
|
for _ in range(self.cfg.policy.learn.update_per_collect): |
|
self._fetcher(ctx) |
|
if ctx.train_data is None: |
|
break |
|
if self._reward_estimator: |
|
self._reward_estimator(ctx) |
|
self._trainer(ctx) |
|
train_output_queue.append(ctx.train_output) |
|
ctx.train_output = train_output_queue |
|
|
|
|
|
class HERLearner: |
|
""" |
|
Overview: |
|
        The class of the learner with Hindsight Experience Replay (HER). \
|
        Use the `__call__` method to execute the data fetching and training \
|
process. |
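    Examples:
        A minimal usage sketch (an illustration, not code from this module): ``cfg``, ``policy``, \
        ``buffer_`` and ``her_reward_model`` are placeholders assumed to be built elsewhere.

        .. code-block:: python

            from ding.framework import task, OnlineRLContext

            with task.start(ctx=OnlineRLContext()):
                # ... collector / data-pusher middlewares are registered here ...
                task.use(HERLearner(cfg, policy, buffer_, her_reward_model))
                task.run()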
|
""" |
|
|
|
def __init__( |
|
self, |
|
cfg: EasyDict, |
|
        policy: 'Policy',
|
buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]], |
|
her_reward_model, |
|
) -> None: |
|
""" |
|
Arguments: |
|
- cfg (:obj:`EasyDict`): Config. |
|
- policy (:obj:`Policy`): The policy to be trained. |
|
- buffer\_ (:obj:`Buffer`): The replay buffer to store the data for training. |
|
            - her_reward_model (:obj:`HerRewardModel`): The HER reward model used by ``her_data_enhancer`` \
                to enhance the sampled training data.
|
""" |
|
self.cfg = cfg |
|
self._fetcher = task.wrap(her_data_enhancer(cfg, buffer_, her_reward_model)) |
|
self._trainer = task.wrap(trainer(cfg, policy)) |
|
|
|
def __call__(self, ctx: "OnlineRLContext") -> None: |
|
""" |
|
Output of ctx: |
|
            - train_output (:obj:`List`): The list of training outputs, one entry per completed update.
|
""" |
|
train_output_queue = [] |
|
for _ in range(self.cfg.policy.learn.update_per_collect): |
|
self._fetcher(ctx) |
|
if ctx.train_data is None: |
|
break |
|
self._trainer(ctx) |
|
train_output_queue.append(ctx.train_output) |
|
ctx.train_output = train_output_queue |
|
|