module {
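// Numerically stable log-softmax over a 50257-entry vocabulary, one row per
// program, processed in 1x2048 tiles. Reading the IR: %arg0 holds the bf16
// input logits, %arg1 receives the per-row max (f32), %arg2 the per-row sum
// of exponentials (f32), and %arg3 the bf16 log-softmax output; %arg4 and
// %arg5 are unused in the body (presumably numel arguments).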
tt.func public @triton__0d1d2d3d4de5(%arg0: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg4: i64 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg5: i64) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<1x2048xbf16>
%c50257_i64 = arith.constant 50257 : i64
%cst_0 = arith.constant dense<true> : tensor<1x2048xi1>
%c50257_i32 = arith.constant 50257 : i32
%c2048_i32 = arith.constant 2048 : i32
%c0_i32 = arith.constant 0 : i32
%cst_1 = arith.constant dense<50257> : tensor<1x2048xi64>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<1x2048xf32>
%cst_3 = arith.constant dense<0xFF800000> : tensor<1x2048xf32>
%0 = tt.get_program_id x : i32
%1 = arith.extsi %0 : i32 to i64
%2 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32>
%3 = tt.expand_dims %2 {axis = 0 : i32} : (tensor<2048xi32>) -> tensor<1x2048xi32>
%4 = arith.extsi %3 : tensor<1x2048xi32> to tensor<1x2048xi64>
%5 = arith.muli %1, %c50257_i64 : i64
%6 = tt.splat %5 : (i64) -> tensor<1x2048xi64>
%7 = tt.splat %arg0 : (!tt.ptr<bf16, 1>) -> tensor<1x2048x!tt.ptr<bf16, 1>>
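// Pass 1: tiled running max over the row. The cmpf/ori/xori/select sequence
// takes the incoming value whenever the accumulator is neither strictly
// greater nor already NaN, so a NaN input enters the accumulator and then
// sticks; out-of-bounds lanes (%32 false) keep the previous accumulator.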
%8 = scf.for %arg6 = %c0_i32 to %c50257_i32 step %c2048_i32 iter_args(%arg7 = %cst_3) -> (tensor<1x2048xf32>) : i32 {
%29 = arith.extsi %arg6 : i32 to i64
%30 = tt.splat %29 : (i64) -> tensor<1x2048xi64>
%31 = arith.addi %30, %4 : tensor<1x2048xi64>
%32 = arith.cmpi slt, %31, %cst_1 : tensor<1x2048xi64>
%33 = arith.addi %31, %6 : tensor<1x2048xi64>
%34 = tt.addptr %7, %33 : tensor<1x2048x!tt.ptr<bf16, 1>>, tensor<1x2048xi64>
%35 = tt.load %34, %32, %cst {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x2048xbf16>
%36 = arith.extf %35 : tensor<1x2048xbf16> to tensor<1x2048xf32>
%37 = arith.cmpf ogt, %arg7, %36 : tensor<1x2048xf32>
%38 = arith.cmpf une, %arg7, %arg7 : tensor<1x2048xf32>
%39 = arith.ori %37, %38 : tensor<1x2048xi1>
%40 = arith.xori %39, %cst_0 : tensor<1x2048xi1>
%41 = arith.andi %32, %40 : tensor<1x2048xi1>
%42 = arith.select %41, %36, %arg7 : tensor<1x2048xi1>, tensor<1x2048xf32>
scf.yield %42 : tensor<1x2048xf32>
}
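// Reduce the 1x2048 partial maxima along axis 1 to the row max, using the
// same NaN-propagating maximum as the combine function.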
%9 = "tt.reduce"(%8) <{axis = 1 : i32}> ({
^bb0(%arg6: f32, %arg7: f32):
%29 = arith.cmpf ogt, %arg6, %arg7 : f32
%30 = arith.cmpf une, %arg6, %arg6 : f32
%31 = arith.ori %29, %30 : i1
%32 = arith.select %31, %arg6, %arg7 : f32
tt.reduce.return %32 : f32
}) : (tensor<1x2048xf32>) -> tensor<1xf32>
%10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
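// Store the row max to %arg1[row].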
%11 = tt.addptr %arg1, %1 : !tt.ptr<f32, 1>, i64
%12 = tt.splat %11 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
tt.store %12, %10 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>
%13 = arith.muli %1, %c50257_i64 : i64
%14 = tt.splat %13 : (i64) -> tensor<1x2048xi64>
%15 = tt.splat %arg0 : (!tt.ptr<bf16, 1>) -> tensor<1x2048x!tt.ptr<bf16, 1>>
%16 = tt.broadcast %10 : (tensor<1x1xf32>) -> tensor<1x2048xf32>
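// Pass 2: accumulate sum(exp(x - row_max)) in 2048-wide tiles; masked-off
// lanes (%32 false) keep the previous partial sum.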
%17 = scf.for %arg6 = %c0_i32 to %c50257_i32 step %c2048_i32 iter_args(%arg7 = %cst_2) -> (tensor<1x2048xf32>) : i32 {
%29 = arith.extsi %arg6 : i32 to i64
%30 = tt.splat %29 : (i64) -> tensor<1x2048xi64>
%31 = arith.addi %30, %4 : tensor<1x2048xi64>
%32 = arith.cmpi slt, %31, %cst_1 : tensor<1x2048xi64>
%33 = arith.addi %31, %14 : tensor<1x2048xi64>
%34 = tt.addptr %15, %33 : tensor<1x2048x!tt.ptr<bf16, 1>>, tensor<1x2048xi64>
%35 = tt.load %34, %32, %cst {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x2048xbf16>
%36 = arith.extf %35 : tensor<1x2048xbf16> to tensor<1x2048xf32>
%37 = arith.subf %36, %16 : tensor<1x2048xf32>
%38 = math.exp %37 : tensor<1x2048xf32>
%39 = arith.addf %arg7, %38 : tensor<1x2048xf32>
%40 = arith.select %32, %39, %arg7 : tensor<1x2048xi1>, tensor<1x2048xf32>
scf.yield %40 : tensor<1x2048xf32>
}
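// Reduce the partial sums along axis 1 to the softmax denominator.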
%18 = "tt.reduce"(%17) <{axis = 1 : i32}> ({
^bb0(%arg6: f32, %arg7: f32):
%29 = arith.addf %arg6, %arg7 : f32
tt.reduce.return %29 : f32
}) : (tensor<1x2048xf32>) -> tensor<1xf32>
%19 = tt.expand_dims %18 {axis = 1 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
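// Store the row sum to %arg2[row].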
%20 = tt.addptr %arg2, %1 : !tt.ptr<f32, 1>, i64
%21 = tt.splat %20 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
tt.store %21, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>
%22 = arith.muli %1, %c50257_i64 : i64
%23 = tt.splat %22 : (i64) -> tensor<1x2048xi64>
%24 = tt.splat %arg0 : (!tt.ptr<bf16, 1>) -> tensor<1x2048x!tt.ptr<bf16, 1>>
%25 = tt.broadcast %10 : (tensor<1x1xf32>) -> tensor<1x2048xf32>
%26 = math.log %19 : tensor<1x1xf32>
%27 = tt.broadcast %26 : (tensor<1x1xf32>) -> tensor<1x2048xf32>
%28 = tt.splat %arg3 : (!tt.ptr<bf16, 1>) -> tensor<1x2048x!tt.ptr<bf16, 1>>
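// Pass 3: recompute x - row_max, subtract log(sum), and store the bf16
// log-softmax to %arg3 under the same vocabulary mask.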
scf.for %arg6 = %c0_i32 to %c50257_i32 step %c2048_i32 : i32 {
%29 = arith.extsi %arg6 : i32 to i64
%30 = tt.splat %29 : (i64) -> tensor<1x2048xi64>
%31 = arith.addi %30, %4 : tensor<1x2048xi64>
%32 = arith.cmpi slt, %31, %cst_1 : tensor<1x2048xi64>
%33 = arith.addi %31, %23 : tensor<1x2048xi64>
%34 = tt.addptr %24, %33 : tensor<1x2048x!tt.ptr<bf16, 1>>, tensor<1x2048xi64>
%35 = tt.load %34, %32, %cst {cache = 1 : i32, evict = 2 : i32, isVolatile = false} : tensor<1x2048xbf16>
%36 = arith.extf %35 : tensor<1x2048xbf16> to tensor<1x2048xf32>
%37 = arith.subf %36, %25 : tensor<1x2048xf32>
%38 = arith.subf %37, %27 : tensor<1x2048xf32>
%39 = tt.addptr %28, %33 : tensor<1x2048x!tt.ptr<bf16, 1>>, tensor<1x2048xi64>
%40 = arith.truncf %38 : tensor<1x2048xf32> to tensor<1x2048xbf16>
tt.store %39, %40, %32 {cache = 1 : i32, evict = 1 : i32} : tensor<1x2048xbf16>
}
tt.return
}
}