|
from typing import TYPE_CHECKING, Callable |
|
from ding.framework import task |
|
if TYPE_CHECKING: |
|
from ding.framework import OnlineRLContext |
|
|
|
|
|
def priority_calculator(priority_calculation_fn: Callable) -> Callable: |
|
""" |
|
Overview: |
|
The middleware that calculates the priority of the collected data. |
|
Arguments: |
|
- priority_calculation_fn (:obj:`Callable`): The function that calculates the priority of the collected data. |
|
""" |
|
|
|
if task.router.is_active and not task.has_role(task.role.COLLECTOR): |
|
return task.void() |
|
|
|
def _priority_calculator(ctx: "OnlineRLContext") -> None: |
|
|
|
priority = priority_calculation_fn(ctx.trajectories) |
|
for i in range(len(priority)): |
|
ctx.trajectories[i]['priority'] = priority[i] |
|
|
|
return _priority_calculator |
|
|