File size: 1,006 Bytes
db26c81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
#pragma once
#include "./common.h"
/**
 * @brief Flags each task against the minimum priority of its group.
 *
 * Tasks are partitioned into groups, and the tasks of a group are visited
 * in order of their priority. For every group, the minimum priority value
 * among its unvisited tasks is computed. The output for a task is false
 * when the task's priority equals that computed minimum, and true otherwise.
 *
 * @param group    group of each task, shape is (batch_size, task_num)
 * @param priority priority of each task, shape is (batch_size, task_num)
 * @param value    whether each task is visited or not,
 *                 shape is (batch_size, task_num)
 * @return the result, shape is (batch_size, task_num)
 */
torch::Tensor task_group_priority(
    const torch::Tensor& group,
    const torch::Tensor& priority,
    const torch::Tensor& value);
/**
 * @brief Raw-pointer variant of task_group_priority; judging by its name,
 *        the CPU implementation (the definition is not in this header).
 *
 * NOTE(review): the buffer layout is not visible here. Presumably every
 * pointer refers to batch_size * task_num elements laid out like the
 * (batch_size, task_num) tensors of task_group_priority -- confirm against
 * the definition.
 * NOTE(review): group, priority and value look like read-only inputs but are
 * declared non-const; const-qualifying them would require the same change
 * in the definition.
 *
 * @param group      group of each task
 * @param priority   priority of each task
 * @param value      whether each task is visited or not
 * @param output     buffer that receives the per-task result
 * @param batch_size presumably the size of the batch dimension -- confirm
 * @param task_num   presumably the number of tasks per batch -- confirm
 * @param group_num  presumably the number of distinct groups -- confirm
 */
void task_group_priority_cpu(
    int* group, int* priority, bool* value, bool* output,
    int batch_size, int task_num, int group_num);
/**
 * @brief Raw-pointer variant of task_group_priority; judging by its name,
 *        the CUDA implementation (the definition is not in this header).
 *
 * NOTE(review): the buffer layout and memory space are not visible here.
 * Presumably every pointer refers to batch_size * task_num elements laid out
 * like the (batch_size, task_num) tensors of task_group_priority, residing
 * in memory accessible from the selected device -- confirm against the
 * definition.
 * NOTE(review): group, priority and value look like read-only inputs but are
 * declared non-const; const-qualifying them would require the same change
 * in the definition.
 *
 * @param group      group of each task
 * @param priority   priority of each task
 * @param value      whether each task is visited or not
 * @param output     buffer that receives the per-task result
 * @param batch_size presumably the size of the batch dimension -- confirm
 * @param task_num   presumably the number of tasks per batch -- confirm
 * @param group_num  presumably the number of distinct groups -- confirm
 * @param device     presumably the index of the CUDA device to run on -- confirm
 */
void task_group_priority_cuda(
    int* group, int* priority, bool* value, bool* output,
    int batch_size, int task_num, int group_num, int device);
|