File size: 6,156 Bytes
89650c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
#pragma once
#include <tuple>
#include <torch/extension.h>
//#include <ATen/SparseTensorUtils.h>
#include <ATen/native/SparseTensorUtils.h>
// Prototypes for the rspmm extension operators.
//
// The declarations sit inside namespace `at`, which is why Tensor, TensorArg
// and CheckedFrom are written unqualified below.
namespace at {

// NOTE(review): a using-directive in a header is visible to every file that
// includes this header. It is kept as-is because translation units including
// this file may rely on at::sparse names being reachable from `at` -- confirm
// before narrowing or removing it.
using namespace at::sparse;

// ---------------------------------------------------------------------------
// Argument checks
// ---------------------------------------------------------------------------

// Check helper taking the five TensorArgs of a forward operator plus the
// CheckedFrom tag `c`. Returns nothing; presumably it raises on invalid
// arguments -- verify against the definition.
void rspmm_forward_check(
    CheckedFrom c,
    const TensorArg& edge_index_arg,
    const TensorArg& edge_type_arg,
    const TensorArg& edge_weight_arg,
    const TensorArg& relation_arg,
    const TensorArg& input_arg);

// Backward counterpart of rspmm_forward_check: same arguments plus the
// TensorArgs for `output` and `output_grad`. Returns nothing.
void rspmm_backward_check(
    CheckedFrom c,
    const TensorArg& edge_index_arg,
    const TensorArg& edge_type_arg,
    const TensorArg& edge_weight_arg,
    const TensorArg& relation_arg,
    const TensorArg& input_arg,
    const TensorArg& output_arg,
    const TensorArg& output_grad_arg);

// Takes an index tensor and an int `size` and returns a Tensor.
// NOTE(review): the name suggests an index -> row-pointer (CSR-style)
// conversion; the exact contract is not visible from this header -- confirm.
Tensor ind2ptr(const Tensor& index, int size);

// ---------------------------------------------------------------------------
// CPU operators
//
// Every variant comes as a forward/backward pair with identical signatures:
//   forward : (edge_index, edge_type, edge_weight, relation, input) -> Tensor
//   backward: the same five tensors plus `output` and `output_grad`
//             -> std::tuple of three Tensors
//
// NOTE(review): the three returned tensors are presumably the gradients for
// edge_weight, relation and input -- confirm against the implementation.
// NOTE(review): in rspmm_<a>_<b>_*, <a> (add/min/max) presumably names the
// aggregation and <b> (mul/add) the operation combining relation and input
// -- confirm against the kernels.
// ---------------------------------------------------------------------------

// add / mul
Tensor rspmm_add_mul_forward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// min / mul
Tensor rspmm_min_mul_forward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// max / mul
Tensor rspmm_max_mul_forward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// add / add
Tensor rspmm_add_add_forward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// min / add
Tensor rspmm_min_add_forward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// max / add
Tensor rspmm_max_add_forward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cpu(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// ---------------------------------------------------------------------------
// CUDA operators
//
// Declared only when the CUDA_OP macro is defined. The signatures mirror the
// CPU operators above one-to-one; only the `_cuda` suffix differs.
// ---------------------------------------------------------------------------
#ifdef CUDA_OP

// add / mul
Tensor rspmm_add_mul_forward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// min / mul
Tensor rspmm_min_mul_forward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// max / mul
Tensor rspmm_max_mul_forward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// add / add
Tensor rspmm_add_add_forward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// min / add
Tensor rspmm_min_add_forward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

// max / add
Tensor rspmm_max_add_forward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input);

std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cuda(
    const Tensor& edge_index, const Tensor& edge_type, const Tensor& edge_weight,
    const Tensor& relation, const Tensor& input,
    const Tensor& output, const Tensor& output_grad);

#endif  // CUDA_OP

}  // namespace at