#include "task_group_priority.h"

#include <limits>
#include <memory>

/// CPU implementation of the task-group-priority mask.
///
/// For every batch row, and for every group, the valid task(s) whose priority
/// equals the group's extremal priority are marked true in `output`; all other
/// tasks are false.
///
/// NOTE(review): the original initializes the per-group accumulator to
/// std::numeric_limits<int>::max(), which implies the *lowest* priority value
/// wins within a group (lower number = higher precedence) — confirm against
/// the CUDA kernel / callers.
///
/// @param group      [batch_size, task_num] group id of each task, in [0, group_num).
/// @param priority   [batch_size, task_num] priority value of each task.
/// @param value      [batch_size, task_num] validity flag; invalid tasks never win.
/// @param output     [batch_size, task_num] result mask (written, not read).
/// @param batch_size number of batch rows.
/// @param task_num   number of tasks per row.
/// @param group_num  number of distinct groups (max group id + 1).
void task_group_priority_cpu(
    int* group, int* priority, bool* value, bool* output,
    int batch_size, int task_num, int group_num) {
    // Per-group best priority seen so far in the current batch row.
    // std::make_unique is the standard spelling of torch::make_unique.
    auto temp = std::make_unique<int[]>(group_num);
    for (int b = 0; b < batch_size; b++) {
        // Reset: "no valid task seen yet" sentinel for every group.
        for (int g = 0; g < group_num; g++) {
            temp[g] = std::numeric_limits<int>::max();
        }
        // Pass 1: find the winning (minimum) priority among valid tasks
        // of each group in this batch row.
        for (int i = 0; i < task_num; i++) {
            const int idx = b * task_num + i;
            if (value[idx] && priority[idx] < temp[group[idx]]) {
                temp[group[idx]] = priority[idx];
            }
        }
        // Pass 2: a task is selected iff it is valid and matches its
        // group's winning priority. Ties within a group all get true.
        for (int i = 0; i < task_num; i++) {
            const int idx = b * task_num + i;
            output[idx] = value[idx] && (priority[idx] == temp[group[idx]]);
        }
    }
}

/// Device-dispatching entry point.
///
/// Validates the inputs, allocates a zeroed bool output tensor of the same
/// [batch_size, task_num] shape, and routes to the CPU or CUDA kernel based
/// on the tensors' device.
///
/// @param group    int tensor [batch_size, task_num]; values must lie in
///                 [0, task_num) (checked via min/max below).
/// @param priority int tensor [batch_size, task_num].
/// @param value    bool tensor [batch_size, task_num].
/// @return         bool tensor [batch_size, task_num] selection mask.
/// @throws         via GRL_CHECK / GRL_ERROR on bad group values, mismatched
///                 tensors, or an unsupported device.
auto task_group_priority(
    torch::Tensor group, torch::Tensor priority, torch::Tensor value)
    -> torch::Tensor {
    auto device = group.device();
    const int batch_size = group.size(0);
    const int task_num = group.size(1);
    // group_num is derived from the data: largest group id + 1.
    const int group_num = group.max().item<int>() + 1;
    const int _group_num = group.min().item<int>();
    GRL_CHECK(group_num <= task_num && _group_num >= 0, "group value error");
    // NOTE(review): the two bool flags of GRL_CHECK_TENSOR are opaque here
    // (presumably grad/contiguity requirements) — kept as in the original.
    GRL_CHECK_TENSOR(group, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(priority, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(value, device, false, false, batch_size, task_num);
    auto output = torch::zeros({batch_size, task_num},
                               torch::dtype(torch::kBool).device(device));
    switch (device.type()) {
    case torch::kCPU:
        task_group_priority_cpu(
            group.data_ptr<int>(), priority.data_ptr<int>(),
            value.data_ptr<bool>(), output.data_ptr<bool>(),
            batch_size, task_num, group_num);
        break;
#ifdef CUDA_FOUND
    case torch::kCUDA:
        task_group_priority_cuda(
            group.data_ptr<int>(), priority.data_ptr<int>(),
            value.data_ptr<bool>(), output.data_ptr<bool>(),
            batch_size, task_num, group_num, device.index());
        break;
#endif
    default:
        GRL_ERROR("unsupported device: %s", device.str().c_str());
    }
    return output;
}