File size: 808 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from typing import TYPE_CHECKING, Callable
from ding.framework import task
if TYPE_CHECKING:
from ding.framework import OnlineRLContext
def priority_calculator(priority_calculation_fn: Callable) -> Callable:
    """
    Overview:
        The middleware that calculates the priority of the collected data and stores each value
        on the corresponding trajectory under the ``'priority'`` key.
    Arguments:
        - priority_calculation_fn (:obj:`Callable`): The function that calculates the priority of the
            collected data. It is called once with ``ctx.trajectories`` and must return an iterable
            whose ``i``-th item is the priority for ``ctx.trajectories[i]``.
    Returns:
        - middleware (:obj:`Callable`): A function that takes the context and writes the priorities
            into ``ctx.trajectories`` in place (it returns ``None``). When the task router is active
            and the current task does not have the collector role, the result of ``task.void()`` is
            returned instead (presumably a no-op middleware; confirm against ``ding.framework``).
    """
    # With an active router, only tasks holding the COLLECTOR role run this middleware.
    if task.router.is_active and not task.has_role(task.role.COLLECTOR):
        return task.void()

    def _priority_calculator(ctx: "OnlineRLContext") -> None:
        priorities = priority_calculation_fn(ctx.trajectories)
        # Write through the trajectory index (instead of zipping) so that a priority without a
        # matching trajectory still fails loudly rather than being silently dropped.
        for index, value in enumerate(priorities):
            ctx.trajectories[index]['priority'] = value

    return _priority_calculator
|