File size: 1,841 Bytes
89650c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
#pragma once
#include <limits>
#include <type_traits>
// Marks the functors below as callable from both host and device code when the
// translation unit is compiled by a CUDA compiler, and expands to nothing for a
// plain C++ compiler. __CUDACC__ is defined for every nvcc compilation pass,
// whereas __CUDA_ARCH__ is defined only for the device pass; keying on
// __CUDACC__ keeps the execution-space annotations identical in both passes.
#if defined(__CUDACC__)
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif

namespace at {

// Binary addition: forward(x, y) = x + y.
// backward_lhs / backward_rhs are the partial derivatives of forward with
// respect to x and y respectively; both are the constant 1.
template <class scalar_t>
struct BinaryAdd {
  HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
    return x + y;
  }
  HOST_DEVICE static scalar_t backward_lhs(scalar_t /*x*/, scalar_t /*y*/) {
    return 1;
  }
  HOST_DEVICE static scalar_t backward_rhs(scalar_t /*x*/, scalar_t /*y*/) {
    return 1;
  }
};

// Binary multiplication: forward(x, y) = x * y.
// d(x * y)/dx = y and d(x * y)/dy = x.
template <class scalar_t>
struct BinaryMul {
  HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
    return x * y;
  }
  HOST_DEVICE static scalar_t backward_lhs(scalar_t /*x*/, scalar_t y) {
    return y;
  }
  HOST_DEVICE static scalar_t backward_rhs(scalar_t x, scalar_t /*y*/) {
    return x;
  }
};

// N-ary sum written as a fold step: forward(result, x) combines an accumulated
// value with one more element. `zero` is the identity of forward (0 + x == x).
// backward always returns 1 (d(sum)/d(element)).
template <class scalar_t>
struct NaryAdd {
  HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
    return result + x;
  }
  HOST_DEVICE static scalar_t backward(scalar_t /*result*/, scalar_t /*x*/) {
    return 1;
  }
  static constexpr scalar_t zero = 0;
};

// N-ary minimum written as a fold step: forward(result, x) returns the smaller
// of the two. backward returns 1 when x equals result and 0 otherwise
// (presumably result is the reduced output and x an input element, so the
// gradient goes to the element(s) that attained the minimum -- confirm
// against callers).
// NOTE(review): if x is NaN, `result < x` is false and NaN is returned; if only
// result is NaN, x is returned and the NaN is dropped, so NaN propagation
// depends on operand order -- confirm this is intended.
template <class scalar_t>
struct NaryMin {
  HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
    return result < x ? result : x;
  }
  HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
    return result == x ? 1 : 0;
  }
  // Identity element of forward. For standard floating-point types this must
  // be +infinity: with max(), a reduction whose inputs are all +inf would
  // return max() instead of +inf. Other types keep their largest finite value.
  static constexpr scalar_t zero =
      (std::is_floating_point<scalar_t>::value &&
       std::numeric_limits<scalar_t>::has_infinity)
          ? std::numeric_limits<scalar_t>::infinity()
          : std::numeric_limits<scalar_t>::max();
};

// N-ary maximum written as a fold step: forward(result, x) returns the larger
// of the two. backward returns 1 when x equals result and 0 otherwise
// (presumably routing the gradient to the element(s) that attained the
// maximum -- confirm against callers).
// NOTE(review): NaN handling is order-dependent in the same way as NaryMin.
template <class scalar_t>
struct NaryMax {
  HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
    return result > x ? result : x;
  }
  HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
    return result == x ? 1 : 0;
  }
  // Identity element of forward. For standard floating-point types this must
  // be -infinity: with lowest(), a reduction whose inputs are all -inf would
  // return lowest() instead of -inf. Other types keep their lowest finite value.
  static constexpr scalar_t zero =
      (std::is_floating_point<scalar_t>::value &&
       std::numeric_limits<scalar_t>::has_infinity)
          ? -std::numeric_limits<scalar_t>::infinity()
          : std::numeric_limits<scalar_t>::lowest();
};

}  // namespace at