File size: 808 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from typing import TYPE_CHECKING, Callable
from ding.framework import task
if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


def priority_calculator(priority_calculation_fn: Callable) -> Callable:
    """
    Overview:
        Build a middleware that stores a priority value on each collected trajectory.
    Arguments:
        - priority_calculation_fn (:obj:`Callable`): Function called with ``ctx.trajectories``. Its result is \
            read positionally: entry ``i`` is written to ``ctx.trajectories[i]['priority']``.
    Returns:
        - middleware (:obj:`Callable`): A function that takes the context and updates its trajectories in \
            place. When the task router is active and this task does not have the collector role, \
            ``task.void()`` is returned instead.
    """

    # The role is only queried when the router is active (short-circuit evaluation).
    router_active = task.router.is_active
    if router_active and not task.has_role(task.role.COLLECTOR):
        return task.void()

    def _priority_calculator(ctx: "OnlineRLContext") -> None:
        trajectories = ctx.trajectories
        computed = priority_calculation_fn(trajectories)
        # Only the first ``len(computed)`` trajectories are touched; each is mutated in place.
        for index, value in enumerate(computed):
            trajectories[index]['priority'] = value

    return _priority_calculator