Spaces:
Running
on
T4
Running
on
T4
File size: 1,682 Bytes
1ce5e18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
// Compile-time constants for the kernels declared in this header.
// No use site is visible here, so the intended usage notes are assumptions.

// Value 32. NOTE(review): presumably the number of threads per warp; it also
// equals the 32-wide block dimension in the shape comments on the kernel
// parameters in this header — confirm whether the two are meant to be tied together.
#define WARP_SIZE 32
// All 32 bits set. NOTE(review): presumably the participant mask handed to
// *_sync warp intrinsics (every lane participating) — confirm at the use sites.
#define FULL_MASK 0xffffffff
// Value 256. NOTE(review): presumably the threads-per-block count used when
// launching these kernels — confirm against the host-side launch code.
#define OPTIMAL_THREADS 256
// Forward declaration of the index-max kernel; the definition is not in this
// header. Tensor layouts are the ones given by the inline shape comments.
// NOTE(review): from the names, this presumably computes maxima of index_vals
// grouped through `indices` into max_vals and writes them back in the
// index_vals layout as max_vals_scatter — confirm against the definition.
// NOTE(review): which buffers are read and which are written is not visible
// here; all pointers are non-const.
// NOTE(review): expected grid/block layout and shared-memory needs are not
// stated in this header — see the launch site.
__global__ void index_max_cuda_kernel(
float *index_vals, // [batch_size, 32, num_block]
int *indices, // [batch_size, num_block]
float *max_vals, // [batch_size, A_num_block * 32]
float *max_vals_scatter, // [batch_size, 32, num_block]
long batch_size,
long A_num_block,
long B_num_block,
long num_block
);
// Forward declaration of the dense-to-sparse matrix-multiply kernel; the
// definition is not in this header. Tensor layouts are the ones given by the
// inline shape comments.
// NOTE(review): from the names, this presumably multiplies blocks of dense_A
// and dense_B selected through `indices` and stores one 32x32 result per
// selected block in sparse_C — confirm against the definition, including how
// an entry of `indices` encodes the (A block, B block) pair.
// NOTE(review): expected grid/block layout and shared-memory needs are not
// stated in this header — see the launch site.
__global__ void mm_to_sparse_cuda_kernel(
float *dense_A, // [batch_size, A_num_block, dim, 32]
float *dense_B, // [batch_size, B_num_block, dim, 32]
int *indices, // [batch_size, num_block]
float *sparse_C, // [batch_size, num_block, 32, 32]
long batch_size,
long A_num_block,
long B_num_block,
long dim,
long num_block
);
// Forward declaration of the sparse-times-dense matrix-multiply kernel; the
// definition is not in this header. Tensor layouts are the ones given by the
// inline shape comments.
// NOTE(review): from the names, this presumably multiplies the 32x32 blocks of
// sparse_A (located through `indices`) with dense_B and accumulates into
// dense_C — confirm against the definition, including whether dense_C must be
// zero-initialized by the caller.
// NOTE(review): expected grid/block layout and shared-memory needs are not
// stated in this header — see the launch site.
__global__ void sparse_dense_mm_cuda_kernel(
float *sparse_A, // [batch_size, num_block, 32, 32]
int *indices, // [batch_size, num_block]
float *dense_B, // [batch_size, B_num_block, dim, 32]
float *dense_C, // [batch_size, A_num_block, dim, 32]
long batch_size,
long A_num_block,
long B_num_block,
long dim,
long num_block
);
// Forward declaration of the block-sparse reduce-sum kernel; the definition is
// not in this header. Tensor layouts are the ones given by the inline shape
// comments.
// NOTE(review): from the names, this presumably sums the 32x32 blocks of
// sparse_A along one axis and accumulates the per-block results into dense_C
// at positions derived from `indices` — confirm the reduced axis and whether
// dense_C must be zero-initialized by the caller.
// NOTE(review): expected grid/block layout and shared-memory needs are not
// stated in this header — see the launch site.
__global__ void reduce_sum_cuda_kernel(
float *sparse_A, // [batch_size, num_block, 32, 32]
int *indices, // [batch_size, num_block]
float *dense_C, // [batch_size, A_num_block, 32]
long batch_size,
long A_num_block,
long B_num_block,
long num_block
);
// Forward declaration of the dense-to-block-sparse scatter kernel; the
// definition is not in this header. Tensor layouts are the ones given by the
// inline shape comments.
// NOTE(review): from the names, this presumably expands the per-block vectors
// of dense_A into the 32x32 blocks of sparse_C selected through `indices` —
// confirm against the definition, including along which axis the values are
// broadcast.
// NOTE(review): expected grid/block layout and shared-memory needs are not
// stated in this header — see the launch site.
__global__ void scatter_cuda_kernel(
float *dense_A, // [batch_size, A_num_block, 32]
int *indices, // [batch_size, num_block]
float *sparse_C, // [batch_size, num_block, 32, 32]
long batch_size,
long A_num_block,
long B_num_block,
long num_block
);
|