0-hero committed d572127 (verified) · Parent: e7aa429

Add files using upload-large-folder tool

.triton/dump/76fb48b96c75cb8e388c291a18ef9b02/triton_.ttir ADDED
@@ -0,0 +1,153 @@
+ module {
+   tt.func public @triton__0d1d2d3d4d5d6de7de(%arg0: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} {
+     %cst = arith.constant dense<0.000000e+00> : tensor<2x128xbf16>
+     %cst_0 = arith.constant 0.000000e+00 : f32
+     %cst_1 = arith.constant dense<1.000000e+00> : tensor<2x128xf32>
+     %c256_i32 = arith.constant 256 : i32
+     %c128_i32 = arith.constant 128 : i32
+     %c0_i32 = arith.constant 0 : i32
+     %cst_2 = arith.constant dense<256> : tensor<2x1xi64>
+     %cst_3 = arith.constant dense<0> : tensor<2x1xi64>
+     %cst_4 = arith.constant dense<50257> : tensor<2x1xi64>
+     %cst_5 = arith.constant dense<9.99999974E-6> : tensor<2x1xf32>
+     %cst_6 = arith.constant dense<2.560000e+02> : tensor<2x1xf32>
+     %cst_7 = arith.constant dense<0.000000e+00> : tensor<1x128xf32>
+     %cst_8 = arith.constant dense<0.000000e+00> : tensor<2x128xf32>
+     %cst_9 = arith.constant dense<256> : tensor<2x1xi32>
+     %cst_10 = arith.constant dense<256> : tensor<1x128xi32>
+     %cst_11 = arith.constant dense<512> : tensor<2x1xi32>
+     %c2_i32 = arith.constant 2 : i32
+     %0 = tt.get_program_id x : i32
+     %1 = arith.muli %0, %c2_i32 : i32
+     %2 = tt.make_range {end = 2 : i32, start = 0 : i32} : tensor<2xi32>
+     %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<2xi32>) -> tensor<2x1xi32>
+     %4 = tt.splat %1 : (i32) -> tensor<2x1xi32>
+     %5 = arith.addi %4, %3 : tensor<2x1xi32>
+     %6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
+     %7 = tt.expand_dims %6 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32>
+     %8 = tt.splat %arg0 : (!tt.ptr<i64, 1>) -> tensor<2x1x!tt.ptr<i64, 1>>
+     %9 = tt.addptr %8, %5 : tensor<2x1x!tt.ptr<i64, 1>>, tensor<2x1xi32>
+     %10 = tt.load %9 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<2x1xi64>
+     %11 = arith.remsi %5, %cst_11 : tensor<2x1xi32>
+     %12 = arith.muli %11, %cst_9 : tensor<2x1xi32>
+     %13 = tt.broadcast %12 : (tensor<2x1xi32>) -> tensor<2x128xi32>
+     %14 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<2x128x!tt.ptr<f32, 1>>
+     %15 = arith.muli %5, %cst_9 : tensor<2x1xi32>
+     %16 = tt.broadcast %15 : (tensor<2x1xi32>) -> tensor<2x128xi32>
+     %17 = tt.splat %arg3 : (!tt.ptr<bf16, 1>) -> tensor<2x128x!tt.ptr<bf16, 1>>
+     %18 = arith.addi %10, %cst_4 : tensor<2x1xi64>
+     %19 = arith.cmpi slt, %10, %cst_3 : tensor<2x1xi64>
+     %20 = arith.select %19, %18, %10 : tensor<2x1xi1>, tensor<2x1xi64>
+     %21 = arith.cmpi sge, %20, %cst_3 : tensor<2x1xi64>
+     %22 = arith.cmpi slt, %20, %cst_4 : tensor<2x1xi64>
+     %23 = arith.andi %21, %22 : tensor<2x1xi1>
+     %24 = arith.muli %20, %cst_2 : tensor<2x1xi64>
+     %25 = tt.broadcast %24 : (tensor<2x1xi64>) -> tensor<2x128xi64>
+     %26 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<2x128x!tt.ptr<f32, 1>>
+     %27:3 = scf.for %arg8 = %c0_i32 to %c256_i32 step %c128_i32 iter_args(%arg9 = %cst_8, %arg10 = %cst_8, %arg11 = %cst_8) -> (tensor<2x128xf32>, tensor<2x128xf32>, tensor<2x128xf32>) : i32 {
+       %51 = tt.splat %arg8 : (i32) -> tensor<1x128xi32>
+       %52 = arith.addi %51, %7 : tensor<1x128xi32>
+       %53 = arith.cmpi slt, %52, %cst_10 : tensor<1x128xi32>
+       %54 = tt.broadcast %52 : (tensor<1x128xi32>) -> tensor<2x128xi32>
+       %55 = arith.addi %54, %13 : tensor<2x128xi32>
+       %56 = tt.addptr %14, %55 : tensor<2x128x!tt.ptr<f32, 1>>, tensor<2x128xi32>
+       %57 = tt.broadcast %53 : (tensor<1x128xi1>) -> tensor<2x128xi1>
+       %58 = tt.load %56, %57, %cst_8 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<2x128xf32>
+       %59 = arith.addi %54, %16 : tensor<2x128xi32>
+       %60 = tt.addptr %17, %59 : tensor<2x128x!tt.ptr<bf16, 1>>, tensor<2x128xi32>
+       %61 = tt.load %60, %57, %cst {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<2x128xbf16>
+       %62 = arith.extf %61 : tensor<2x128xbf16> to tensor<2x128xf32>
+       tt.assert %23, "index out of bounds: 0 <= tmp3 < 50257", "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py", "<module>", 1892 : tensor<2x1xi1>
+       %63 = arith.extsi %52 : tensor<1x128xi32> to tensor<1x128xi64>
+       %64 = tt.broadcast %63 : (tensor<1x128xi64>) -> tensor<2x128xi64>
+       %65 = arith.addi %64, %25 : tensor<2x128xi64>
+       %66 = tt.addptr %26, %65 : tensor<2x128x!tt.ptr<f32, 1>>, tensor<2x128xi64>
+       %67 = tt.load %66, %57, %cst_8 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<2x128xf32>
+       %68 = arith.addf %67, %58 : tensor<2x128xf32>
+       %69 = arith.addf %68, %62 : tensor<2x128xf32>
+       %70 = arith.subf %69, %arg9 : tensor<2x128xf32>
+       %71 = arith.addf %arg11, %cst_1 : tensor<2x128xf32>
+       %72 = arith.divf %70, %71 : tensor<2x128xf32>
+       %73 = arith.addf %arg9, %72 : tensor<2x128xf32>
+       %74 = arith.subf %69, %73 : tensor<2x128xf32>
+       %75 = arith.mulf %70, %74 : tensor<2x128xf32>
+       %76 = arith.addf %arg10, %75 : tensor<2x128xf32>
+       %77 = arith.select %57, %73, %arg9 : tensor<2x128xi1>, tensor<2x128xf32>
+       %78 = arith.select %57, %76, %arg10 : tensor<2x128xi1>, tensor<2x128xf32>
+       %79 = arith.select %57, %71, %arg11 : tensor<2x128xi1>, tensor<2x128xf32>
+       scf.yield %77, %78, %79 : tensor<2x128xf32>, tensor<2x128xf32>, tensor<2x128xf32>
+     }
+     %28:3 = "tt.reduce"(%27#0, %27#1, %27#2) <{axis = 1 : i32}> ({
+     ^bb0(%arg8: f32, %arg9: f32, %arg10: f32, %arg11: f32, %arg12: f32, %arg13: f32):
+       %51 = arith.subf %arg11, %arg8 : f32
+       %52 = arith.addf %arg10, %arg13 : f32
+       %53 = arith.cmpf oeq, %52, %cst_0 : f32
+       %54 = arith.divf %arg13, %52 : f32
+       %55 = arith.select %53, %cst_0, %54 : f32
+       %56 = arith.mulf %51, %55 : f32
+       %57 = arith.addf %arg8, %56 : f32
+       %58 = arith.addf %arg9, %arg12 : f32
+       %59 = arith.mulf %51, %51 : f32
+       %60 = arith.mulf %59, %arg10 : f32
+       %61 = arith.mulf %60, %55 : f32
+       %62 = arith.addf %58, %61 : f32
+       tt.reduce.return %57, %62, %52 : f32, f32, f32
+     }) : (tensor<2x128xf32>, tensor<2x128xf32>, tensor<2x128xf32>) -> (tensor<2xf32>, tensor<2xf32>, tensor<2xf32>)
+     %29 = tt.expand_dims %28#0 {axis = 1 : i32} : (tensor<2xf32>) -> tensor<2x1xf32>
+     %30 = tt.expand_dims %28#1 {axis = 1 : i32} : (tensor<2xf32>) -> tensor<2x1xf32>
+     %31 = arith.muli %11, %cst_9 : tensor<2x1xi32>
+     %32 = tt.broadcast %31 : (tensor<2x1xi32>) -> tensor<2x128xi32>
+     %33 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<2x128x!tt.ptr<f32, 1>>
+     %34 = arith.muli %5, %cst_9 : tensor<2x1xi32>
+     %35 = tt.broadcast %34 : (tensor<2x1xi32>) -> tensor<2x128xi32>
+     %36 = tt.splat %arg3 : (!tt.ptr<bf16, 1>) -> tensor<2x128x!tt.ptr<bf16, 1>>
+     %37 = tt.splat %arg4 : (!tt.ptr<f32, 1>) -> tensor<1x128x!tt.ptr<f32, 1>>
+     %38 = arith.addi %10, %cst_4 : tensor<2x1xi64>
+     %39 = arith.cmpi slt, %10, %cst_3 : tensor<2x1xi64>
+     %40 = arith.select %39, %38, %10 : tensor<2x1xi1>, tensor<2x1xi64>
+     %41 = arith.cmpi sge, %40, %cst_3 : tensor<2x1xi64>
+     %42 = arith.cmpi slt, %40, %cst_4 : tensor<2x1xi64>
+     %43 = arith.andi %41, %42 : tensor<2x1xi1>
+     %44 = arith.muli %40, %cst_2 : tensor<2x1xi64>
+     %45 = tt.broadcast %44 : (tensor<2x1xi64>) -> tensor<2x128xi64>
+     %46 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<2x128x!tt.ptr<f32, 1>>
+     %47 = tt.broadcast %29 : (tensor<2x1xf32>) -> tensor<2x128xf32>
+     %48 = arith.divf %30, %cst_6 : tensor<2x1xf32>
+     %49 = arith.addf %48, %cst_5 : tensor<2x1xf32>
+     %50 = tt.splat %arg5 : (!tt.ptr<bf16, 1>) -> tensor<2x128x!tt.ptr<bf16, 1>>
+     scf.for %arg8 = %c0_i32 to %c256_i32 step %c128_i32 : i32 {
+       %51 = tt.splat %arg8 : (i32) -> tensor<1x128xi32>
+       %52 = arith.addi %51, %7 : tensor<1x128xi32>
+       %53 = arith.cmpi slt, %52, %cst_10 : tensor<1x128xi32>
+       %54 = tt.broadcast %52 : (tensor<1x128xi32>) -> tensor<2x128xi32>
+       %55 = arith.addi %54, %32 : tensor<2x128xi32>
+       %56 = tt.addptr %33, %55 : tensor<2x128x!tt.ptr<f32, 1>>, tensor<2x128xi32>
+       %57 = tt.broadcast %53 : (tensor<1x128xi1>) -> tensor<2x128xi1>
+       %58 = tt.load %56, %57, %cst_8 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<2x128xf32>
+       %59 = arith.addi %54, %35 : tensor<2x128xi32>
+       %60 = tt.addptr %36, %59 : tensor<2x128x!tt.ptr<bf16, 1>>, tensor<2x128xi32>
+       %61 = tt.load %60, %57, %cst {cache = 1 : i32, evict = 2 : i32, isVolatile = false} : tensor<2x128xbf16>
+       %62 = arith.extf %61 : tensor<2x128xbf16> to tensor<2x128xf32>
+       %63 = tt.addptr %37, %52 : tensor<1x128x!tt.ptr<f32, 1>>, tensor<1x128xi32>
+       %64 = tt.load %63, %53, %cst_7 {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x128xf32>
+       tt.assert %43, "index out of bounds: 0 <= tmp16 < 50257", "/usr/local/lib/python3.10/dist-packages/torch/_inductor/codecache.py", "<module>", 1892 : tensor<2x1xi1>
+       %65 = arith.extsi %52 : tensor<1x128xi32> to tensor<1x128xi64>
+       %66 = tt.broadcast %65 : (tensor<1x128xi64>) -> tensor<2x128xi64>
+       %67 = arith.addi %66, %45 : tensor<2x128xi64>
+       %68 = tt.addptr %46, %67 : tensor<2x128x!tt.ptr<f32, 1>>, tensor<2x128xi64>
+       %69 = tt.load %68, %57, %cst_8 {cache = 1 : i32, evict = 2 : i32, isVolatile = false} : tensor<2x128xf32>
+       %70 = arith.addf %69, %58 : tensor<2x128xf32>
+       %71 = arith.addf %70, %62 : tensor<2x128xf32>
+       %72 = arith.subf %71, %47 : tensor<2x128xf32>
+       %73 = tt.extern_elementwise %49 {libname = "libdevice", libpath = "/usr/local/lib/python3.10/dist-packages/triton/language/../third_party/cuda/lib/libdevice.10.bc", pure = true, symbol = "__nv_rsqrtf"} : (tensor<2x1xf32>) -> tensor<2x1xf32>
+       %74 = tt.broadcast %73 : (tensor<2x1xf32>) -> tensor<2x128xf32>
+       %75 = arith.mulf %72, %74 : tensor<2x128xf32>
+       %76 = tt.broadcast %64 : (tensor<1x128xf32>) -> tensor<2x128xf32>
+       %77 = arith.mulf %75, %76 : tensor<2x128xf32>
+       %78 = tt.addptr %50, %59 : tensor<2x128x!tt.ptr<bf16, 1>>, tensor<2x128xi32>
+       %79 = arith.truncf %77 : tensor<2x128xf32> to tensor<2x128xbf16>
+       tt.store %78, %79, %57 {cache = 1 : i32, evict = 1 : i32} : tensor<2x128xbf16>
+     }
+     tt.return
+   }
+ }
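
A rough reading of the dump above, for orientation: it looks like a TorchInductor-generated kernel that fuses a word-embedding gather (a 50257 x 256 fp32 table at %arg1, with the negative-index wrap and the tt.assert bounds check), a position-embedding read indexed by row % 512 (%arg2), and a bf16 residual add (%arg3), then applies a weight-only LayerNorm (eps ~ 1e-5, weight at %arg4) and stores bf16 to %arg5. The sketch below is a hedged PyTorch restatement of that math under those assumptions; every name in it is hypothetical, not taken from the kernel's source.

import torch

def fused_embed_layernorm(tok_ids, wte, wpe, resid, ln_weight, eps=1e-5):
    # Hypothetical restatement of the TTIR above; all names are assumptions.
    #   tok_ids:   (N,) int64 token indices (%arg0), negatives wrapped by +50257
    #   wte:       (50257, 256) fp32 word-embedding table (%arg1)
    #   wpe:       (512, 256) fp32 position-embedding table (%arg2)
    #   resid:     (N, 256) bf16 residual input (%arg3)
    #   ln_weight: (256,) fp32 LayerNorm weight, no bias (%arg4)
    tok = torch.where(tok_ids < 0, tok_ids + 50257, tok_ids)
    assert bool(((tok >= 0) & (tok < 50257)).all()), \
        "index out of bounds: 0 <= tmp3 < 50257"   # mirrors the tt.assert
    pos = torch.arange(tok_ids.numel(), device=tok_ids.device) % 512
    x = wte[tok] + wpe[pos] + resid.float()        # accumulate in fp32
    mean = x.mean(-1, keepdim=True)                # Welford loop + tt.reduce
    var = x.var(-1, unbiased=False, keepdim=True)  # M2 / 256 (%cst_6)
    y = (x - mean) * torch.rsqrt(var + eps) * ln_weight
    return y.to(torch.bfloat16)                    # truncf + store to %arg5

On the design visible in the IR: the scf.for carries running (mean, M2, count) per element across two 128-wide column tiles, and the tt.reduce merges those partials with the parallel Welford combine, which is why the variance only appears at the end as %30 divided by 256 before the libdevice rsqrt.
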
.triton/dump/89f8cc1079aa03024e56dc2aee42813a/triton_.ttir ADDED
@@ -0,0 +1,91 @@
+ module {
+   tt.func public @triton__0d1d2d3d4d5d6e7de(%arg0: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<i64, 1> {tt.divisibility = 16 : i32}, %arg6: i64 {tt.max_divisibility = 8 : i32}, %arg7: i64 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}) attributes {noinline = false} {
+     %c8_i64 = arith.constant 8 : i64
+     %c7680_i64 = arith.constant 7680 : i64
+     %c385973760_i64 = arith.constant 385973760 : i64
+     %cst = arith.constant dense<0.000000e+00> : tensor<1x2048xbf16>
+     %cst_0 = arith.constant dense<50257> : tensor<1x2048xi64>
+     %cst_1 = arith.constant dense<7680> : tensor<1x2048xi64>
+     %c2048_i32 = arith.constant 2048 : i32
+     %c7680_i32 = arith.constant 7680 : i32
+     %c0_i32 = arith.constant 0 : i32
+     %cst_2 = arith.constant dense<-1> : tensor<1x2048xi64>
+     %cst_3 = arith.constant dense<0> : tensor<1x2048xi64>
+     %cst_4 = arith.constant dense<0.000000e+00> : tensor<1x2048xf32>
+     %0 = tt.get_program_id x : i32
+     %1 = arith.extsi %0 : i32 to i64
+     %2 = arith.cmpi slt, %1, %c8_i64 : i64
+     %3 = tt.splat %2 : (i1) -> tensor<1x1xi1>
+     %4 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32>
+     %5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<2048xi32>) -> tensor<1x2048xi32>
+     %6 = arith.extsi %5 : tensor<1x2048xi32> to tensor<1x2048xi64>
+     %7 = arith.muli %1, %c7680_i64 : i64
+     %8 = tt.splat %7 : (i64) -> tensor<1x2048xi64>
+     %9 = tt.splat %arg0 : (!tt.ptr<i64, 1>) -> tensor<1x2048x!tt.ptr<i64, 1>>
+     %10 = tt.splat %2 : (i1) -> tensor<1x2048xi1>
+     %11 = tt.splat %arg2 : (!tt.ptr<f32, 1>) -> tensor<1x2048x!tt.ptr<f32, 1>>
+     %12 = tt.splat %arg3 : (!tt.ptr<f32, 1>) -> tensor<1x2048x!tt.ptr<f32, 1>>
+     %13 = arith.muli %1, %c385973760_i64 : i64
+     %14 = tt.splat %13 : (i64) -> tensor<1x2048xi64>
+     %15 = tt.splat %arg1 : (!tt.ptr<bf16, 1>) -> tensor<1x2048x!tt.ptr<bf16, 1>>
+     %16:2 = scf.for %arg8 = %c0_i32 to %c7680_i32 step %c2048_i32 iter_args(%arg9 = %cst_4, %arg10 = %cst_3) -> (tensor<1x2048xf32>, tensor<1x2048xi64>) : i32 {
+       %25 = arith.extsi %arg8 : i32 to i64
+       %26 = tt.splat %25 : (i64) -> tensor<1x2048xi64>
+       %27 = arith.addi %26, %6 : tensor<1x2048xi64>
+       %28 = arith.cmpi slt, %27, %cst_1 : tensor<1x2048xi64>
+       %29 = arith.addi %27, %8 : tensor<1x2048xi64>
+       %30 = tt.addptr %9, %29 : tensor<1x2048x!tt.ptr<i64, 1>>, tensor<1x2048xi64>
+       %31 = arith.andi %28, %10 : tensor<1x2048xi1>
+       %32 = tt.load %30, %31, %cst_3 {cache = 1 : i32, evict = 2 : i32, isVolatile = false} : tensor<1x2048xi64>
+       %33 = tt.addptr %11, %29 : tensor<1x2048x!tt.ptr<f32, 1>>, tensor<1x2048xi64>
+       %34 = tt.load %33, %31, %cst_4 {cache = 1 : i32, evict = 2 : i32, isVolatile = false} : tensor<1x2048xf32>
+       %35 = tt.addptr %12, %29 : tensor<1x2048x!tt.ptr<f32, 1>>, tensor<1x2048xi64>
+       %36 = tt.load %35, %31, %cst_4 {cache = 1 : i32, evict = 2 : i32, isVolatile = false} : tensor<1x2048xf32>
+       %37 = arith.cmpi ne, %32, %cst_2 : tensor<1x2048xi64>
+       %38 = arith.select %37, %32, %cst_3 : tensor<1x2048xi1>, tensor<1x2048xi64>
+       %39 = arith.addi %38, %cst_0 : tensor<1x2048xi64>
+       %40 = arith.cmpi slt, %38, %cst_3 : tensor<1x2048xi64>
+       %41 = arith.select %40, %39, %38 : tensor<1x2048xi1>, tensor<1x2048xi64>
+       %42 = arith.cmpi sge, %41, %cst_3 : tensor<1x2048xi64>
+       %43 = arith.cmpi slt, %41, %cst_0 : tensor<1x2048xi64>
+       %44 = arith.andi %42, %43 : tensor<1x2048xi1>
+       tt.assert %44, "index out of bounds: 0 <= tmp7 < 50257", "<frozen importlib._bootstrap_external>", "_call_with_frames_removed", 883 : tensor<1x2048xi1>
+       %45 = arith.muli %27, %cst_0 : tensor<1x2048xi64>
+       %46 = arith.addi %41, %45 : tensor<1x2048xi64>
+       %47 = arith.addi %46, %14 : tensor<1x2048xi64>
+       %48 = tt.addptr %15, %47 : tensor<1x2048x!tt.ptr<bf16, 1>>, tensor<1x2048xi64>
+       %49 = tt.load %48, %31, %cst {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x2048xbf16>
+       %50 = arith.extf %49 : tensor<1x2048xbf16> to tensor<1x2048xf32>
+       %51 = arith.subf %50, %34 : tensor<1x2048xf32>
+       %52 = math.log %36 : tensor<1x2048xf32>
+       %53 = arith.subf %51, %52 : tensor<1x2048xf32>
+       %54 = arith.subf %cst_4, %53 : tensor<1x2048xf32>
+       %55 = arith.select %37, %54, %cst_4 : tensor<1x2048xi1>, tensor<1x2048xf32>
+       %56 = arith.addf %arg9, %55 : tensor<1x2048xf32>
+       %57 = arith.select %31, %56, %arg9 : tensor<1x2048xi1>, tensor<1x2048xf32>
+       %58 = arith.extui %37 : tensor<1x2048xi1> to tensor<1x2048xi64>
+       %59 = arith.addi %arg10, %58 : tensor<1x2048xi64>
+       %60 = arith.select %31, %59, %arg10 : tensor<1x2048xi1>, tensor<1x2048xi64>
+       scf.yield %57, %60 : tensor<1x2048xf32>, tensor<1x2048xi64>
+     }
+     %17 = "tt.reduce"(%16#0) <{axis = 1 : i32}> ({
+     ^bb0(%arg8: f32, %arg9: f32):
+       %25 = arith.addf %arg8, %arg9 : f32
+       tt.reduce.return %25 : f32
+     }) : (tensor<1x2048xf32>) -> tensor<1xf32>
+     %18 = tt.expand_dims %17 {axis = 1 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
+     %19 = tt.addptr %arg4, %1 : !tt.ptr<f32, 1>, i64
+     %20 = tt.splat %19 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
+     tt.store %20, %18, %3 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>
+     %21 = "tt.reduce"(%16#1) <{axis = 1 : i32}> ({
+     ^bb0(%arg8: i64, %arg9: i64):
+       %25 = arith.addi %arg8, %arg9 : i64
+       tt.reduce.return %25 : i64
+     }) : (tensor<1x2048xi64>) -> tensor<1xi64>
+     %22 = tt.expand_dims %21 {axis = 1 : i32} : (tensor<1xi64>) -> tensor<1x1xi64>
+     %23 = tt.addptr %arg5, %1 : !tt.ptr<i64, 1>, i64
+     %24 = tt.splat %23 : (!tt.ptr<i64, 1>) -> tensor<1x1x!tt.ptr<i64, 1>>
+     tt.store %24, %22, %3 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xi64>
+     tt.return
+   }
+ }
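
The second dump reads like a cross-entropy reduction over the same vocabulary: each of 8 program rows (guarded by %c8_i64) scans 7680 targets, gathers the target's logit out of a bf16 block with row stride 385973760 = 7680 * 50257 (%arg1), subtracts a precomputed per-position max (%arg2) and the log of a precomputed sum-of-exponentials (%arg3), zeroes positions whose target is -1, and stores a per-row loss sum (%arg4) plus a count of valid targets (%arg5). A hedged PyTorch sketch under those assumptions, with all names hypothetical:

import torch

def ce_partial_sums(targets, logits, row_max, row_sumexp, ignore_index=-1):
    # Hypothetical restatement of the TTIR above; all names are assumptions.
    #   targets:    (8, 7680) int64 (%arg0); ignore_index marks padding
    #   logits:     (8, 7680, 50257) bf16 (%arg1)
    #   row_max:    (8, 7680) fp32 precomputed max over the vocab (%arg2)
    #   row_sumexp: (8, 7680) fp32 precomputed sum of exponentials (%arg3)
    valid = targets != ignore_index                   # %37
    safe = torch.where(valid, targets, torch.zeros_like(targets))
    safe = torch.where(safe < 0, safe + 50257, safe)  # wrap negatives (%41)
    picked = logits.float().gather(-1, safe.unsqueeze(-1)).squeeze(-1)
    nll = -(picked - row_max - row_sumexp.log())      # -log_softmax at target
    nll = torch.where(valid, nll, torch.zeros_like(nll))
    return nll.sum(-1), valid.sum(-1)                 # stored to %arg4, %arg5

Dividing the two stored outputs on the host would then give the mean loss over non-ignored tokens, consistent with the usual ignore_index = -1 convention.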