smpanaro's picture
Faster argmax by skipping logit concat
f519906
program(1.0)
[buildInfo = dict<tensor<string, []>, tensor<string, []>>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})]
{
    func main<ios16>(tensor<fp16, [1, 64, 16384]> logits_0, tensor<fp16, [1, 64, 16384]> logits_1, tensor<fp16, [1, 64, 16384]> logits_2, tensor<fp16, [1, 64, 16384]> logits_3, tensor<fp16, [1, 64, 16384]> logits_4, tensor<fp16, [1, 64, 16384]> logits_5, tensor<fp16, [1, 64, 16384]> logits_6, tensor<fp16, [1, 64, 13568]> logits_7) {
            tensor<int32, []> top_k = const()[name = tensor<string, []>("top_k"), val = tensor<int32, []>(1)];
            tensor<int32, []> last_axis = const()[name = tensor<string, []>("last_axis"), val = tensor<int32, []>(-1)];
            tensor<bool, []> is_ascending = const()[name = tensor<string, []>("is_ascending"), val = tensor<bool, []>(false)];
            tensor<bool, []> chunk_sort = const()[name = tensor<string, []>("chunk_sort"), val = tensor<bool, []>(false)];
            tensor<bool, []> final_sort = const()[name = tensor<string, []>("final_sort"), val = tensor<bool, []>(true)];
            tensor<bool, []> with_indices = const()[name = tensor<string, []>("with_indices"), val = tensor<bool, []>(true)];
            tensor<bool, []> no_interleave = const()[name = tensor<string, []>("no_interleave"), val = tensor<bool, []>(false)];
            tensor<int32, [1]> offset_1 = const()[name = tensor<string, []>("offset_1"), val = tensor<int32, [1]>([16384])];
            tensor<int32, [1]> offset_2 = const()[name = tensor<string, []>("offset_2"), val = tensor<int32, [1]>([32768])];
            tensor<int32, [1]> offset_3 = const()[name = tensor<string, []>("offset_3"), val = tensor<int32, [1]>([49152])];
            tensor<int32, [1]> offset_4 = const()[name = tensor<string, []>("offset_4"), val = tensor<int32, [1]>([65536])];
            tensor<int32, [1]> offset_5 = const()[name = tensor<string, []>("offset_5"), val = tensor<int32, [1]>([81920])];
            tensor<int32, [1]> offset_6 = const()[name = tensor<string, []>("offset_6"), val = tensor<int32, [1]>([98304])];
            tensor<int32, [1]> offset_7 = const()[name = tensor<string, []>("offset_7"), val = tensor<int32, [1]>([114688])];
            tensor<fp16, [1, 64, 1]> chunk_max_0, tensor<int32, [1, 64, 1]> chunk_idx_0 = topk(ascending = is_ascending, axis = last_axis, k = top_k, return_indices = with_indices, sort = chunk_sort, x = logits_0)[name = tensor<string, []>("chunk_topk_0")];
            tensor<fp16, [1, 64, 1]> chunk_max_1, tensor<int32, [1, 64, 1]> chunk_idx_1 = topk(ascending = is_ascending, axis = last_axis, k = top_k, return_indices = with_indices, sort = chunk_sort, x = logits_1)[name = tensor<string, []>("chunk_topk_1")];
            tensor<int32, [1, 64, 1]> global_idx_1 = add(x = chunk_idx_1, y = offset_1)[name = tensor<string, []>("global_idx_1")];
            tensor<fp16, [1, 64, 1]> chunk_max_2, tensor<int32, [1, 64, 1]> chunk_idx_2 = topk(ascending = is_ascending, axis = last_axis, k = top_k, return_indices = with_indices, sort = chunk_sort, x = logits_2)[name = tensor<string, []>("chunk_topk_2")];
            tensor<int32, [1, 64, 1]> global_idx_2 = add(x = chunk_idx_2, y = offset_2)[name = tensor<string, []>("global_idx_2")];
            tensor<fp16, [1, 64, 1]> chunk_max_3, tensor<int32, [1, 64, 1]> chunk_idx_3 = topk(ascending = is_ascending, axis = last_axis, k = top_k, return_indices = with_indices, sort = chunk_sort, x = logits_3)[name = tensor<string, []>("chunk_topk_3")];
            tensor<int32, [1, 64, 1]> global_idx_3 = add(x = chunk_idx_3, y = offset_3)[name = tensor<string, []>("global_idx_3")];
            tensor<fp16, [1, 64, 1]> chunk_max_4, tensor<int32, [1, 64, 1]> chunk_idx_4 = topk(ascending = is_ascending, axis = last_axis, k = top_k, return_indices = with_indices, sort = chunk_sort, x = logits_4)[name = tensor<string, []>("chunk_topk_4")];
            tensor<int32, [1, 64, 1]> global_idx_4 = add(x = chunk_idx_4, y = offset_4)[name = tensor<string, []>("global_idx_4")];
            tensor<fp16, [1, 64, 1]> chunk_max_5, tensor<int32, [1, 64, 1]> chunk_idx_5 = topk(ascending = is_ascending, axis = last_axis, k = top_k, return_indices = with_indices, sort = chunk_sort, x = logits_5)[name = tensor<string, []>("chunk_topk_5")];
            tensor<int32, [1, 64, 1]> global_idx_5 = add(x = chunk_idx_5, y = offset_5)[name = tensor<string, []>("global_idx_5")];
            tensor<fp16, [1, 64, 1]> chunk_max_6, tensor<int32, [1, 64, 1]> chunk_idx_6 = topk(ascending = is_ascending, axis = last_axis, k = top_k, return_indices = with_indices, sort = chunk_sort, x = logits_6)[name = tensor<string, []>("chunk_topk_6")];
            tensor<int32, [1, 64, 1]> global_idx_6 = add(x = chunk_idx_6, y = offset_6)[name = tensor<string, []>("global_idx_6")];
            tensor<fp16, [1, 64, 1]> chunk_max_7, tensor<int32, [1, 64, 1]> chunk_idx_7 = topk(ascending = is_ascending, axis = last_axis, k = top_k, return_indices = with_indices, sort = chunk_sort, x = logits_7)[name = tensor<string, []>("chunk_topk_7")];
            tensor<int32, [1, 64, 1]> global_idx_7 = add(x = chunk_idx_7, y = offset_7)[name = tensor<string, []>("global_idx_7")];
            tensor<fp16, [1, 64, 8]> chunk_maxima = concat(axis = last_axis, interleave = no_interleave, values = (chunk_max_0, chunk_max_1, chunk_max_2, chunk_max_3, chunk_max_4, chunk_max_5, chunk_max_6, chunk_max_7))[name = tensor<string, []>("chunk_maxima")];
            tensor<int32, [1, 64, 8]> chunk_indices = concat(axis = last_axis, interleave = no_interleave, values = (chunk_idx_0, global_idx_1, global_idx_2, global_idx_3, global_idx_4, global_idx_5, global_idx_6, global_idx_7))[name = tensor<string, []>("chunk_indices")];
            tensor<fp16, [1, 64, 1]> best_value, tensor<int32, [1, 64, 1]> best_chunk = topk(ascending = is_ascending, axis = last_axis, k = top_k, return_indices = with_indices, sort = final_sort, x = chunk_maxima)[name = tensor<string, []>("best_topk")];
            tensor<int32, [1, 64, 1]> best_index = gather_along_axis(axis = last_axis, indices = best_chunk, x = chunk_indices)[name = tensor<string, []>("best_index")];
            tensor<int32, [1]> squeeze_axes = const()[name = tensor<string, []>("squeeze_axes"), val = tensor<int32, [1]>([-1])];
            tensor<int32, [1, 64]> argmax = squeeze(axes = squeeze_axes, x = best_index)[name = tensor<string, []>("argmax")];
    } -> (argmax);
}