File size: 2,382 Bytes
db26c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include "task_group_priority.h"

#include <algorithm>
#include <limits>
#include <vector>

// CPU kernel: for each batch row, compute the minimum priority among the
// still-pending tasks (value[i] == false) of every group, then flag each
// task whose priority differs from its group's minimum.
//
// group    : [batch_size, task_num] group id of each task, in [0, group_num)
// priority : [batch_size, task_num] priority of each task (lower wins)
// value    : [batch_size, task_num] true -> task excluded from the minimum
// output   : [batch_size, task_num] true -> task's priority is NOT its
//            group's minimum pending priority. A group whose tasks are all
//            excluded keeps its minimum at INT_MAX, so its tasks are all
//            flagged true (unless a priority literally equals INT_MAX).
void task_group_priority_cpu(
        int* group, int* priority, bool* value, bool* output,
        int batch_size, int task_num, int group_num)
{
    // Per-group minimum buffer, allocated once and reset per batch.
    // std::vector instead of torch::make_unique<int[]> keeps this plain
    // CPU loop free of any torch dependency.
    std::vector<int> group_min(group_num);
    for (int b = 0; b < batch_size; b++)
    {
        std::fill(group_min.begin(), group_min.end(), std::numeric_limits<int>::max());

        // Pass 1: per-group minimum priority over tasks not marked done.
        for (int i = 0; i < task_num; i++) {
            if (value[i]) {
                continue;  // excluded tasks do not define the minimum
            }
            group_min[group[i]] = std::min(group_min[group[i]], priority[i]);
        }

        // Pass 2: flag every task that does not match its group's minimum
        // (excluded tasks are still compared, matching the original logic).
        for (int i = 0; i < task_num; i++) {
            output[i] = priority[i] != group_min[group[i]];
        }

        // Advance all pointers to the next batch row.
        group += task_num;
        priority += task_num;
        value += task_num;
        output += task_num;
    }
}

// Batched entry point: validates the inputs, allocates the boolean output
// tensor, and dispatches to the CPU (or CUDA, when built) kernel.
//
// group    : [batch, task] tensor of per-task group ids (read as int)
// priority : [batch, task] tensor of per-task priorities (read as int)
// value    : [batch, task] bool tensor of per-task "skip" flags
// returns  : [batch, task] bool tensor filled in by the kernel
auto task_group_priority(
        const torch::Tensor& group,
        const torch::Tensor& priority,
        const torch::Tensor& value) -> torch::Tensor 
{
    const auto device = group.device();

    const int batch_size = group.size(0);
    const int task_num = group.size(1);

    // Group ids must lie in [0, group_num) with group_num <= task_num.
    const int group_num = group.max().item<int>() + 1;
    const int min_group_id = group.min().item<int>();
    GRL_CHECK(group_num <= task_num && min_group_id >= 0, "group value error");

    // All three inputs must agree on device and [batch, task] shape.
    GRL_CHECK_TENSOR(group, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(priority, device, false, false, batch_size, task_num);
    GRL_CHECK_TENSOR(value, device, false, false, batch_size, task_num);

    auto output = torch::zeros({batch_size, task_num}, torch::dtype(torch::kBool).device(device));

    if (device.type() == torch::kCPU) {
        task_group_priority_cpu(group.data_ptr<int>(), priority.data_ptr<int>(), value.data_ptr<bool>(),
                                output.data_ptr<bool>(), batch_size, task_num, group_num);
    }
#ifdef CUDA_FOUND
    else if (device.type() == torch::kCUDA) {
        task_group_priority_cuda(group.data_ptr<int>(), priority.data_ptr<int>(), value.data_ptr<bool>(),
                                 output.data_ptr<bool>(), batch_size, task_num, group_num, device.index());
    }
#endif
    else {
        GRL_ERROR("unsupported device: %s", device.str().c_str());
    }

    return output;
}