// Viewer metadata (not part of the program): file size 5,472 bytes; identifier e8ffc70.
#include <torch/serialize/tensor.h>
#include <vector>
// #include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "interpolate_gpu.h"
// extern THCState *state;
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAEvent.h>
// cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Input-validation macros used by the wrapper functions in this file.
//
// On failure each check prints "<expression> must be ... at <file>:<line>" to
// stderr and terminates the process with exit(-1).
// NOTE(review): if this file is loaded as an extension into a host interpreter,
// exit(-1) takes the whole process down -- confirm that this is intended.
//
// All three macros expand to exactly one statement (do { ... } while (0)), so
// they are safe under an unbraced if/else, and the argument is parenthesized so
// that an arbitrary expression can be passed.

// Fails unless `x` reports is_cuda().
#define CHECK_CUDA(x) do { \
    if (!(x).is_cuda()) { \
        fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
        exit(-1); \
    } \
} while (0)

// Fails unless `x` reports is_contiguous().
#define CHECK_CONTIGUOUS(x) do { \
    if (!(x).is_contiguous()) { \
        fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
        exit(-1); \
    } \
} while (0)

// Runs CHECK_CUDA and then CHECK_CONTIGUOUS on `x` as a single statement.
#define CHECK_INPUT(x) do { \
    CHECK_CUDA(x); \
    CHECK_CONTIGUOUS(x); \
} while (0)
void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor,
at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
const float *unknown = unknown_tensor.data<float>();
const float *known = known_tensor.data<float>();
float *dist2 = dist2_tensor.data<float>();
int *idx = idx_tensor.data<int>();
three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx);
}
void three_interpolate_wrapper_fast(int b, int c, int m, int n,
at::Tensor points_tensor,
at::Tensor idx_tensor,
at::Tensor weight_tensor,
at::Tensor out_tensor) {
const float *points = points_tensor.data<float>();
const float *weight = weight_tensor.data<float>();
float *out = out_tensor.data<float>();
const int *idx = idx_tensor.data<int>();
three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out);
}
void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
at::Tensor grad_out_tensor,
at::Tensor idx_tensor,
at::Tensor weight_tensor,
at::Tensor grad_points_tensor) {
const float *grad_out = grad_out_tensor.data<float>();
const float *weight = weight_tensor.data<float>();
float *grad_points = grad_points_tensor.data<float>();
const int *idx = idx_tensor.data<int>();
three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points);
}
// Three-nearest-neighbour search over "stacked" batches: validates the tensors,
// reads the sizes, and hands raw pointers to three_nn_kernel_launcher_stack.
//
// Inputs:
//   unknown_tensor:           (N1 + N2 ..., 3)
//   unknown_batch_cnt_tensor: (batch_size), [N1, N2, ...]
//   known_tensor:             (M1 + M2 ..., 3)
//   known_batch_cnt_tensor:   (batch_size), [M1, M2, ...]
// Outputs (filled by the launcher):
//   dist2_tensor: (N1 + N2 ..., 3) l2 distance to the three nearest neighbors
//   idx_tensor:   (N1 + N2 ..., 3) index of the three nearest neighbors
void three_nn_wrapper_stack(at::Tensor unknown_tensor,
                            at::Tensor unknown_batch_cnt_tensor,
                            at::Tensor known_tensor,
                            at::Tensor known_batch_cnt_tensor,
                            at::Tensor dist2_tensor,
                            at::Tensor idx_tensor) {
    // Each argument has to be a contiguous CUDA tensor.
    CHECK_INPUT(unknown_tensor);
    CHECK_INPUT(unknown_batch_cnt_tensor);
    CHECK_INPUT(known_tensor);
    CHECK_INPUT(known_batch_cnt_tensor);
    CHECK_INPUT(dist2_tensor);
    CHECK_INPUT(idx_tensor);

    // Sizes come from the leading dimension of the respective tensors.
    const int num_batches = unknown_batch_cnt_tensor.size(0);
    const int num_unknown = unknown_tensor.size(0);
    const int num_known = known_tensor.size(0);

    // Raw pointers for the launcher.
    const float *unknown_ptr = unknown_tensor.data<float>();
    const int *unknown_cnt_ptr = unknown_batch_cnt_tensor.data<int>();
    const float *known_ptr = known_tensor.data<float>();
    const int *known_cnt_ptr = known_batch_cnt_tensor.data<int>();
    float *dist2_ptr = dist2_tensor.data<float>();
    int *idx_ptr = idx_tensor.data<int>();

    three_nn_kernel_launcher_stack(num_batches, num_unknown, num_known,
                                   unknown_ptr, unknown_cnt_ptr,
                                   known_ptr, known_cnt_ptr,
                                   dist2_ptr, idx_ptr);
}
// Three-point feature interpolation over "stacked" batches: validates the
// tensors and hands raw pointers to three_interpolate_kernel_launcher_stack.
//
// Inputs:
//   features_tensor: (M1 + M2 ..., C)
//   idx_tensor:      [N1 + N2 ..., 3]
//   weight_tensor:   [N1 + N2 ..., 3]
// Output (filled by the launcher):
//   out_tensor:      (N1 + N2 ..., C)
void three_interpolate_wrapper_stack(at::Tensor features_tensor,
                                     at::Tensor idx_tensor,
                                     at::Tensor weight_tensor,
                                     at::Tensor out_tensor) {
    // Each argument has to be a contiguous CUDA tensor.
    CHECK_INPUT(features_tensor);
    CHECK_INPUT(idx_tensor);
    CHECK_INPUT(weight_tensor);
    CHECK_INPUT(out_tensor);

    // Row count is taken from the output, channel count from the features.
    const int num_out_rows = out_tensor.size(0);
    const int num_channels = features_tensor.size(1);

    // Raw pointers for the launcher.
    const float *features_ptr = features_tensor.data<float>();
    const float *weight_ptr = weight_tensor.data<float>();
    const int *idx_ptr = idx_tensor.data<int>();
    float *out_ptr = out_tensor.data<float>();

    three_interpolate_kernel_launcher_stack(num_out_rows, num_channels,
                                            features_ptr, idx_ptr, weight_ptr, out_ptr);
}
// Backward pass of the stacked three-point interpolation: validates the tensors
// and hands raw pointers to three_interpolate_grad_kernel_launcher_stack.
//
// Inputs:
//   grad_out_tensor:      (N1 + N2 ..., C)
//   idx_tensor:           [N1 + N2 ..., 3]
//   weight_tensor:        [N1 + N2 ..., 3]
// Output (filled by the launcher):
//   grad_features_tensor: (M1 + M2 ..., C)
void three_interpolate_grad_wrapper_stack(at::Tensor grad_out_tensor,
                                          at::Tensor idx_tensor,
                                          at::Tensor weight_tensor,
                                          at::Tensor grad_features_tensor) {
    // Each argument has to be a contiguous CUDA tensor.
    CHECK_INPUT(grad_out_tensor);
    CHECK_INPUT(idx_tensor);
    CHECK_INPUT(weight_tensor);
    CHECK_INPUT(grad_features_tensor);

    // Both sizes are read from the incoming gradient.
    const int num_rows = grad_out_tensor.size(0);
    const int num_channels = grad_out_tensor.size(1);

    // Raw pointers for the launcher.
    const float *grad_out_ptr = grad_out_tensor.data<float>();
    const float *weight_ptr = weight_tensor.data<float>();
    const int *idx_ptr = idx_tensor.data<int>();
    float *grad_features_ptr = grad_features_tensor.data<float>();

    three_interpolate_grad_kernel_launcher_stack(num_rows, num_channels,
                                                 grad_out_ptr, idx_ptr, weight_ptr,
                                                 grad_features_ptr);
}