// File size: 2,382 bytes | db26c81
#include "task_group_priority.h"

#include <algorithm>
#include <limits>
#include <vector>
/**
 * @brief CPU kernel: for every batch row, flag each task whose priority
 *        differs from the smallest priority found in its group.
 *
 * All four arrays are read/written as `batch_size` consecutive rows of
 * `task_num` elements each. Rows are processed independently.
 *
 * Per row:
 *   1. For each group, find the minimum `priority[i]` over the tasks with
 *      `value[i] == false`. A group with no such task keeps the sentinel
 *      `std::numeric_limits<int>::max()`.
 *   2. Write `output[i] = (priority[i] != minimum of group[i])` for every
 *      task, including those with `value[i] == true`.
 *
 * @param group      Group index of each task; every entry must lie in
 *                   `[0, group_num)` (not checked here).
 * @param priority   Priority of each task.
 * @param value      Per-task flag; tasks with a true flag are ignored when
 *                   the per-group minimum is computed.
 * @param output     Result mask, written for every task.
 * @param batch_size Number of rows.
 * @param task_num   Number of tasks per row.
 * @param group_num  Number of distinct group slots to track.
 */
void task_group_priority_cpu(
    int* group, int* priority, bool* value, bool* output,
    int batch_size, int task_num, int group_num)
{
    // Degenerate sizes: nothing to compute, and a non-positive group_num
    // must never reach the scratch-buffer allocation below.
    if (batch_size <= 0 || task_num <= 0 || group_num <= 0) {
        return;
    }

    constexpr int kNoCandidate = std::numeric_limits<int>::max();

    // Scratch buffer holding the running minimum priority per group;
    // allocated once and reset for every row.
    std::vector<int> best(static_cast<std::size_t>(group_num));

    for (int b = 0; b < batch_size; ++b) {
        std::fill(best.begin(), best.end(), kNoCandidate);

        // Pass 1: minimum priority per group over the unflagged tasks.
        for (int i = 0; i < task_num; ++i) {
            if (value[i]) {
                continue;
            }
            int& slot = best[group[i]];
            slot = std::min(slot, priority[i]);
        }

        // Pass 2: a task is marked when its priority is not its group's minimum.
        for (int i = 0; i < task_num; ++i) {
            output[i] = priority[i] != best[group[i]];
        }

        // Step all four pointers to the next row.
        group += task_num;
        priority += task_num;
        value += task_num;
        output += task_num;
    }
}
/**
 * @brief Builds a boolean mask marking, per batch row, every task whose
 *        priority differs from the minimum priority of its group.
 *
 * The batch and task counts are taken from `group.size(0)` and
 * `group.size(1)`. The three inputs are validated with GRL_CHECK_TENSOR
 * against those sizes and `group`'s device, then the work is dispatched to
 * the CPU kernel or (when built with CUDA_FOUND) the CUDA kernel.
 *
 * @param group    Group index per task; read through `data_ptr<int>()`.
 *                 Values must be non-negative and smaller than the task count.
 * @param priority Priority per task; read through `data_ptr<int>()`.
 * @param value    Per-task flag; read through `data_ptr<bool>()`.
 * @return A `kBool` tensor of shape `{batch_size, task_num}` on `group`'s
 *         device. It is returned untouched (empty) when `group` has no
 *         elements.
 *
 * Failures are reported through GRL_CHECK / GRL_CHECK_TENSOR / GRL_ERROR;
 * their exact behaviour is defined outside this file.
 */
auto task_group_priority(
    const torch::Tensor& group,
    const torch::Tensor& priority,
    const torch::Tensor& value) -> torch::Tensor
{
    auto device = group.device();
    // size() yields int64_t; the kernels take int, so narrow explicitly.
    const int batch_size = static_cast<int>(group.size(0));
    const int task_num = static_cast<int>(group.size(1));

    // Validate the tensors before reading any of their contents, so that
    // malformed input is reported by the dedicated checks.
    GRL_CHECK_TENSOR(group, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(priority, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(value, device, false, false, batch_size, task_num);

    auto output = torch::zeros({batch_size, task_num}, torch::dtype(torch::kBool).device(device));

    // max()/min() cannot reduce an empty tensor; there is nothing to mark.
    if (group.numel() == 0) {
        return output;
    }

    const int max_group = group.max().item<int>();
    const int min_group = group.min().item<int>();
    // Same condition as `max_group + 1 <= task_num`, written so the
    // increment cannot overflow before the range has been verified.
    GRL_CHECK(max_group < task_num && min_group >= 0, "group value error");
    const int group_num = max_group + 1;

    switch(device.type())
    {
    case torch::kCPU:
        task_group_priority_cpu(group.data_ptr<int>(), priority.data_ptr<int>(), value.data_ptr<bool>(),
                                output.data_ptr<bool>(), batch_size, task_num, group_num);
        break;
#ifdef CUDA_FOUND
    case torch::kCUDA:
        task_group_priority_cuda(group.data_ptr<int>(), priority.data_ptr<int>(), value.data_ptr<bool>(),
                                 output.data_ptr<bool>(), batch_size, task_num, group_num, device.index());
        break;
#endif
    default:
        GRL_ERROR("unsupported device: %s", device.str().c_str());
    }
    return output;
}
|