File size: 5,472 Bytes
e8ffc70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include <torch/serialize/tensor.h>
#include <vector>
// #include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "interpolate_gpu.h"

// extern THCState *state;

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
// cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// Argument-validation helpers used by the wrappers below.
//
// Failures are reported through TORCH_CHECK, which throws a c10::Error
// (surfaced to Python as a RuntimeError) instead of calling exit(), which
// would terminate the entire host process (e.g. a Python interpreter).
// Every macro expands to a single statement (do { ... } while (0)) so it is
// safe to use as the body of an unbraced if/else.
#define CHECK_CUDA(x) do { \
    TORCH_CHECK((x).is_cuda(), #x " must be CUDA tensor"); \
} while (0)
#define CHECK_CONTIGUOUS(x) do { \
    TORCH_CHECK((x).is_contiguous(), #x " must be contiguous tensor"); \
} while (0)
#define CHECK_INPUT(x) do { \
    CHECK_CUDA(x); \
    CHECK_CONTIGUOUS(x); \
} while (0)


void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 
    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
    const float *unknown = unknown_tensor.data<float>();
    const float *known = known_tensor.data<float>();
    float *dist2 = dist2_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx);
}


void three_interpolate_wrapper_fast(int b, int c, int m, int n,
                         at::Tensor points_tensor,
                         at::Tensor idx_tensor,
                         at::Tensor weight_tensor,
                         at::Tensor out_tensor) {

    const float *points = points_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *out = out_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();


    three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out);
}

void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
                            at::Tensor grad_out_tensor,
                            at::Tensor idx_tensor,
                            at::Tensor weight_tensor,
                            at::Tensor grad_points_tensor) {

    const float *grad_out = grad_out_tensor.data<float>();
    const float *weight = weight_tensor.data<float>();
    float *grad_points = grad_points_tensor.data<float>();
    const int *idx = idx_tensor.data<int>();

    three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points);
}


/**
 * Host-side wrapper around three_nn_kernel_launcher_stack for the "stacked"
 * layout: points of all batch elements are concatenated along dim 0 and the
 * per-batch point counts are passed separately.
 *
 * Inputs:
 *   unknown:           (N1 + N2 ..., 3)
 *   unknown_batch_cnt: (batch_size), [N1, N2, ...]
 *   known:             (M1 + M2 ..., 3)
 *   known_batch_cnt:   (batch_size), [M1, M2, ...]
 * Outputs (filled by the launcher):
 *   dist2: (N1 + N2 ..., 3)  distance to the three nearest neighbors
 *          (NOTE(review): name suggests squared L2 -- confirm)
 *   idx:   (N1 + N2 ..., 3)  index of the three nearest neighbors
 *
 * All tensors must be contiguous CUDA tensors. The two batch-count tensors
 * must have the same length, and both outputs need one row per unknown point.
 * These relations are checked here because the launcher only sees raw
 * pointers and cannot detect mismatched buffers.
 */
void three_nn_wrapper_stack(at::Tensor unknown_tensor,
    at::Tensor unknown_batch_cnt_tensor, at::Tensor known_tensor,
    at::Tensor known_batch_cnt_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor){
    CHECK_INPUT(unknown_tensor);
    CHECK_INPUT(unknown_batch_cnt_tensor);
    CHECK_INPUT(known_tensor);
    CHECK_INPUT(known_batch_cnt_tensor);
    CHECK_INPUT(dist2_tensor);
    CHECK_INPUT(idx_tensor);

    // batch_size is derived from unknown_batch_cnt only, so known_batch_cnt
    // must agree with it.
    TORCH_CHECK(known_batch_cnt_tensor.size(0) == unknown_batch_cnt_tensor.size(0),
                "known_batch_cnt_tensor and unknown_batch_cnt_tensor must have the same length, got ",
                known_batch_cnt_tensor.size(0), " and ", unknown_batch_cnt_tensor.size(0));
    TORCH_CHECK(dist2_tensor.size(0) == unknown_tensor.size(0),
                "dist2_tensor must have one row per unknown point, got ",
                dist2_tensor.size(0), " rows for ", unknown_tensor.size(0), " points");
    TORCH_CHECK(idx_tensor.size(0) == unknown_tensor.size(0),
                "idx_tensor must have one row per unknown point, got ",
                idx_tensor.size(0), " rows for ", unknown_tensor.size(0), " points");

    const int batch_size = static_cast<int>(unknown_batch_cnt_tensor.size(0));
    const int N = static_cast<int>(unknown_tensor.size(0));
    const int M = static_cast<int>(known_tensor.size(0));

    // data_ptr<T>() replaces the deprecated Tensor::data<T>().
    const float *unknown = unknown_tensor.data_ptr<float>();
    const int *unknown_batch_cnt = unknown_batch_cnt_tensor.data_ptr<int>();
    const float *known = known_tensor.data_ptr<float>();
    const int *known_batch_cnt = known_batch_cnt_tensor.data_ptr<int>();
    float *dist2 = dist2_tensor.data_ptr<float>();
    int *idx = idx_tensor.data_ptr<int>();

    three_nn_kernel_launcher_stack(batch_size, N, M, unknown, unknown_batch_cnt, known, known_batch_cnt, dist2, idx);
}


/**
 * Host-side wrapper around three_interpolate_kernel_launcher_stack (stacked
 * layout).
 *
 * Inputs:
 *   features_tensor: (M1 + M2 ..., C)  float32
 *   idx_tensor:      [N1 + N2 ..., 3]  int32 neighbor indices
 *   weight_tensor:   [N1 + N2 ..., 3]  float32 interpolation weights
 * Output (filled by the launcher):
 *   out_tensor:      (N1 + N2 ..., C)  float32
 *
 * N is taken from out_tensor and C from features_tensor. idx/weight must have
 * N rows and out_tensor must have C channels. These are checked here because
 * the launcher only receives raw pointers and a mismatch would read or write
 * out of bounds.
 */
void three_interpolate_wrapper_stack(at::Tensor features_tensor,
    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor) {
    CHECK_INPUT(features_tensor);
    CHECK_INPUT(idx_tensor);
    CHECK_INPUT(weight_tensor);
    CHECK_INPUT(out_tensor);

    const int N = static_cast<int>(out_tensor.size(0));
    const int channels = static_cast<int>(features_tensor.size(1));

    TORCH_CHECK(idx_tensor.size(0) == N,
                "idx_tensor must have ", N, " rows (one per output point), got ", idx_tensor.size(0));
    TORCH_CHECK(weight_tensor.size(0) == N,
                "weight_tensor must have ", N, " rows (one per output point), got ", weight_tensor.size(0));
    TORCH_CHECK(out_tensor.size(1) == channels,
                "out_tensor must have ", channels, " channels to match features_tensor, got ", out_tensor.size(1));

    // data_ptr<T>() replaces the deprecated Tensor::data<T>().
    const float *features = features_tensor.data_ptr<float>();
    const float *weight = weight_tensor.data_ptr<float>();
    const int *idx = idx_tensor.data_ptr<int>();
    float *out = out_tensor.data_ptr<float>();

    three_interpolate_kernel_launcher_stack(N, channels, features, idx, weight, out);
}


/**
 * Host-side wrapper around three_interpolate_grad_kernel_launcher_stack
 * (stacked layout); backward pass of three_interpolate_wrapper_stack.
 *
 * Inputs:
 *   grad_out_tensor: (N1 + N2 ..., C)  float32 upstream gradient
 *   idx_tensor:      [N1 + N2 ..., 3]  int32 neighbor indices
 *   weight_tensor:   [N1 + N2 ..., 3]  float32 interpolation weights
 * Output (written by the launcher):
 *   grad_features_tensor: (M1 + M2 ..., C)  float32
 *   NOTE(review): if the kernel accumulates into this buffer, callers must
 *   zero it beforehand -- confirm against the launcher.
 *
 * N and C are taken from grad_out_tensor. idx/weight must have N rows and
 * grad_features_tensor must have C channels. These are checked here because
 * the launcher only receives raw pointers.
 */
void three_interpolate_grad_wrapper_stack(at::Tensor grad_out_tensor, at::Tensor idx_tensor,
    at::Tensor weight_tensor, at::Tensor grad_features_tensor) {
    CHECK_INPUT(grad_out_tensor);
    CHECK_INPUT(idx_tensor);
    CHECK_INPUT(weight_tensor);
    CHECK_INPUT(grad_features_tensor);

    const int N = static_cast<int>(grad_out_tensor.size(0));
    const int channels = static_cast<int>(grad_out_tensor.size(1));

    TORCH_CHECK(idx_tensor.size(0) == N,
                "idx_tensor must have ", N, " rows (one per grad_out row), got ", idx_tensor.size(0));
    TORCH_CHECK(weight_tensor.size(0) == N,
                "weight_tensor must have ", N, " rows (one per grad_out row), got ", weight_tensor.size(0));
    TORCH_CHECK(grad_features_tensor.size(1) == channels,
                "grad_features_tensor must have ", channels, " channels to match grad_out_tensor, got ",
                grad_features_tensor.size(1));

    // data_ptr<T>() replaces the deprecated Tensor::data<T>().
    const float *grad_out = grad_out_tensor.data_ptr<float>();
    const float *weight = weight_tensor.data_ptr<float>();
    const int *idx = idx_tensor.data_ptr<int>();
    float *grad_features = grad_features_tensor.data_ptr<float>();

    three_interpolate_grad_kernel_launcher_stack(N, channels, grad_out, idx, weight, grad_features);
}