#include "task_group_split.h"
void task_group_split_cpu(
    int* group, bool* value, bool* output,
    const int batch_size, const int task_num, const int group_num)
{
    // Scratch buffer: temp[g] marks whether group g contains at least one
    // true-valued task in the current batch row.
    auto temp = torch::make_unique<bool[]>(group_num);
    for (int b = 0; b < batch_size; b++)
    {
        for (int i = 0; i < group_num; i++) {
            temp[i] = false;
        }
        // Pass 1: mark every group that contains a true-valued task.
        for (int i = 0; i < task_num; i++) {
            if (value[i]) {
                int g = group[i];
                temp[g] = true;
            }
        }
        // Pass 2: the row is "split" if some marked group also contains a
        // false-valued task.
        output[b] = false;
        for (int i = 0; i < task_num; i++) {
            int g = group[i];
            if (temp[g] && !value[i]) {
                output[b] = true;
                break;
            }
        }
        // Advance the raw pointers to the next batch row.
        group += task_num;
        value += task_num;
    }
}
auto task_group_split(
    const Tensor& group, const Tensor& value) -> Tensor
{
    auto device = group.device();
    const int batch_size = group.size(0);
    const int task_num = group.size(1);
    // Group ids must lie in [0, task_num); group_num is one past the
    // largest id actually used.
    const int group_num = group.max().item<int>() + 1;
    const int _group_num = group.min().item<int>();
    GRL_CHECK(group_num <= task_num && _group_num >= 0,
        "group ids must be in [0, task_num)");
    GRL_CHECK_TENSOR(group, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(value, device, false, false, batch_size, task_num);
    auto output = torch::zeros({batch_size}, torch::dtype(torch::kBool).device(device));
    // Dispatch to the CPU or CUDA kernel depending on the input device.
    switch (device.type())
    {
    case torch::kCPU:
        task_group_split_cpu(group.data_ptr<int>(), value.data_ptr<bool>(),
            output.data_ptr<bool>(), batch_size, task_num, group_num);
        break;
#ifdef CUDA_FOUND
    case torch::kCUDA:
        task_group_split_cuda(group.data_ptr<int>(), value.data_ptr<bool>(),
            output.data_ptr<bool>(), batch_size, task_num, group_num, device.index());
        break;
#endif
    default:
        GRL_ERROR("unsupported device: %s", device.str().c_str());
    }
    return output;
}
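// Example usage (a sketch; the tensor literals below are illustrative
// assumptions, not part of this file):
//
//   // Tasks 0,1 belong to group 0; tasks 2,3 belong to group 1.
//   auto group = torch::tensor({{0, 0, 1, 1}}, torch::kInt);   // [1, 4]
//   auto value = torch::tensor({{true, false, true, true}});   // [1, 4]
//   auto split = task_group_split(group, value);
//   // split[0] == true: group 0 holds both a true task (0) and a false
//   // task (1), so the row is reported as split. Group 1 alone would not
//   // trigger it, since all of its tasks are true.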