File size: 1,006 Bytes
db26c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
#pragma once

#include "./common.h"

/**
 * Tasks are partitioned into groups, and the tasks within a group are visited
 * in order of their priority. For every group, the minimum priority value among
 * that group's unvisited tasks is computed. A task's output is false if its
 * priority equals the computed minimum for its group, and true otherwise.
 *
 * @param group     the group of each task, shape (batch_size, task_num)
 * @param priority  the priority of each task, shape (batch_size, task_num)
 * @param value     whether each task has been visited, shape (batch_size, task_num)
 * @return          the result, shape (batch_size, task_num)
 */
torch::Tensor task_group_priority(
        const torch::Tensor& group,
        const torch::Tensor& priority,
        const torch::Tensor& value);

/**
 * CPU backend for task_group_priority, operating on raw buffers.
 *
 * @param group       per-task group ids; presumably batch_size * task_num
 *                    contiguous elements (TODO: confirm layout with the caller)
 * @param priority    per-task priorities, same layout as `group` (assumed)
 * @param value       per-task visited flags, same layout as `group` (assumed)
 * @param output      result buffer, presumably written by this function with
 *                    the same layout as `group` (verify)
 * @param batch_size  number of batch rows
 * @param task_num    number of tasks per batch row
 * @param group_num   number of groups; presumably group ids lie in
 *                    [0, group_num) (NOTE(review): confirm)
 */
void task_group_priority_cpu(
        int* group, int* priority, bool* value, bool* output,
        int batch_size, int task_num, int group_num);

/**
 * CUDA backend for task_group_priority, operating on raw buffers.
 *
 * @param group       per-task group ids; presumably batch_size * task_num
 *                    contiguous elements (TODO: confirm layout with the caller)
 * @param priority    per-task priorities, same layout as `group` (assumed)
 * @param value       per-task visited flags, same layout as `group` (assumed)
 * @param output      result buffer, presumably written by this function with
 *                    the same layout as `group` (verify)
 * @param batch_size  number of batch rows
 * @param task_num    number of tasks per batch row
 * @param group_num   number of groups; presumably group ids lie in
 *                    [0, group_num) (NOTE(review): confirm)
 * @param device      presumably the CUDA device ordinal the buffers live on
 *                    (NOTE(review): confirm against the implementation)
 */
void task_group_priority_cuda(
        int* group, int* priority, bool* value, bool* output,
        int batch_size, int task_num, int group_num, int device);