// Row-wise log-softmax over rows of 50257 elements; one program instance per row.
//
// Reconstructed from a corrupted dump: pointer pointee types, `dense<true>`,
// and the `arith.select` / `tt.store` opcodes had been stripped. Pointee types
// are taken from how each argument is used:
//   %arg0 : bf16 input; row `pid` starts at element pid * 50257
//   %arg1 : f32 per-row maximum, written at element pid
//   %arg2 : f32 per-row sum of exp(x - max), written at element pid
//   %arg3 : bf16 output, x - max - log(sum), same row layout as %arg0
//   %arg4, %arg5 : not referenced in this body (presumably xnumel / rnumel,
//                  specialized into the 50257 constant — TODO confirm)
// Address space 1 on every pointer is assumed to match the dumping Triton
// version's default — NOTE(review): confirm.
//
// Redundant copies of %5/%6/%7 (formerly %13-%15, %22-%24) and of %16
// (formerly %25) were folded into the originals; the SSA-number gaps are
// intentional.
module {
  tt.func public @triton__0d1d2d3d4de5(%arg0: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<bf16, 1> {tt.divisibility = 16 : i32}, %arg4: i64 {tt.divisibility = 16 : i32, tt.max_divisibility = 16 : i32}, %arg5: i64) attributes {noinline = false} {
    // Fill value for masked-off loads.
    %cst = arith.constant dense<0.000000e+00> : tensor<1x2048xbf16>
    %c50257_i64 = arith.constant 50257 : i64
    // All-true mask; XOR with it is a lane-wise logical NOT.
    %cst_0 = arith.constant dense<true> : tensor<1x2048xi1>
    %c50257_i32 = arith.constant 50257 : i32
    %c2048_i32 = arith.constant 2048 : i32
    %c0_i32 = arith.constant 0 : i32
    // Row length, used to mask the ragged final chunk (50257 is not a multiple of 2048).
    %cst_1 = arith.constant dense<50257> : tensor<1x2048xi64>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<1x2048xf32>
    // 0xFF800000 is f32 -inf: identity for the running max.
    %cst_3 = arith.constant dense<0xFF800000> : tensor<1x2048xf32>

    // Row index and this row's base offset (pid * 50257), widened to i64.
    %0 = tt.get_program_id x : i32
    %1 = arith.extsi %0 : i32 to i64
    %2 = tt.make_range {end = 2048 : i32, start = 0 : i32} : tensor<2048xi32>
    %3 = tt.expand_dims %2 {axis = 0 : i32} : (tensor<2048xi32>) -> tensor<1x2048xi32>
    %4 = arith.extsi %3 : tensor<1x2048xi32> to tensor<1x2048xi64>
    %5 = arith.muli %1, %c50257_i64 : i64
    %6 = tt.splat %5 : (i64) -> tensor<1x2048xi64>
    %7 = tt.splat %arg0 : (!tt.ptr<bf16, 1>) -> tensor<1x2048x!tt.ptr<bf16, 1>>

    // Pass 1: lane-wise running max over the row in chunks of 2048.
    // A lane takes the new value x only when it is in bounds and
    // NOT (acc > x OR acc is NaN), so a NaN in either acc or x sticks.
    %8 = scf.for %arg6 = %c0_i32 to %c50257_i32 step %c2048_i32 iter_args(%arg7 = %cst_3) -> (tensor<1x2048xf32>) : i32 {
      %29 = arith.extsi %arg6 : i32 to i64
      %30 = tt.splat %29 : (i64) -> tensor<1x2048xi64>
      %31 = arith.addi %30, %4 : tensor<1x2048xi64>
      %32 = arith.cmpi slt, %31, %cst_1 : tensor<1x2048xi64>
      %33 = arith.addi %31, %6 : tensor<1x2048xi64>
      %34 = tt.addptr %7, %33 : tensor<1x2048x!tt.ptr<bf16, 1>>, tensor<1x2048xi64>
      // evict = 3 here vs evict = 2 in pass 3: presumably EVICT_LAST / EVICT_FIRST
      // in Triton's EvictionPolicy enum, so the row stays cached for later passes — confirm.
      %35 = tt.load %34, %32, %cst {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x2048xbf16>
      %36 = arith.extf %35 : tensor<1x2048xbf16> to tensor<1x2048xf32>
      %37 = arith.cmpf ogt, %arg7, %36 : tensor<1x2048xf32>
      %38 = arith.cmpf une, %arg7, %arg7 : tensor<1x2048xf32>
      %39 = arith.ori %37, %38 : tensor<1x2048xi1>
      %40 = arith.xori %39, %cst_0 : tensor<1x2048xi1>
      %41 = arith.andi %32, %40 : tensor<1x2048xi1>
      %42 = arith.select %41, %36, %arg7 : tensor<1x2048xi1>, tensor<1x2048xf32>
      scf.yield %42 : tensor<1x2048xf32>
    }
    // Collapse the 2048 lanes into one max. Keep a when a > b or a is NaN, else b;
    // this also propagates NaN.
    %9 = "tt.reduce"(%8) <{axis = 1 : i32}> ({
    ^bb0(%arg6: f32, %arg7: f32):
      %29 = arith.cmpf ogt, %arg6, %arg7 : f32
      %30 = arith.cmpf une, %arg6, %arg6 : f32
      %31 = arith.ori %29, %30 : i1
      %32 = arith.select %31, %arg6, %arg7 : f32
      tt.reduce.return %32 : f32
    }) : (tensor<1x2048xf32>) -> tensor<1xf32>
    %10 = tt.expand_dims %9 {axis = 1 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
    // Write the row max to %arg1[pid].
    %11 = tt.addptr %arg1, %1 : !tt.ptr<f32, 1>, i64
    %12 = tt.splat %11 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
    tt.store %12, %10 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>

    // Pass 2: lane-wise sum of exp(x - max). Out-of-bounds lanes keep the
    // accumulator unchanged.
    %16 = tt.broadcast %10 : (tensor<1x1xf32>) -> tensor<1x2048xf32>
    %17 = scf.for %arg6 = %c0_i32 to %c50257_i32 step %c2048_i32 iter_args(%arg7 = %cst_2) -> (tensor<1x2048xf32>) : i32 {
      %29 = arith.extsi %arg6 : i32 to i64
      %30 = tt.splat %29 : (i64) -> tensor<1x2048xi64>
      %31 = arith.addi %30, %4 : tensor<1x2048xi64>
      %32 = arith.cmpi slt, %31, %cst_1 : tensor<1x2048xi64>
      %33 = arith.addi %31, %6 : tensor<1x2048xi64>
      %34 = tt.addptr %7, %33 : tensor<1x2048x!tt.ptr<bf16, 1>>, tensor<1x2048xi64>
      %35 = tt.load %34, %32, %cst {cache = 1 : i32, evict = 3 : i32, isVolatile = false} : tensor<1x2048xbf16>
      %36 = arith.extf %35 : tensor<1x2048xbf16> to tensor<1x2048xf32>
      %37 = arith.subf %36, %16 : tensor<1x2048xf32>
      %38 = math.exp %37 : tensor<1x2048xf32>
      %39 = arith.addf %arg7, %38 : tensor<1x2048xf32>
      %40 = arith.select %32, %39, %arg7 : tensor<1x2048xi1>, tensor<1x2048xf32>
      scf.yield %40 : tensor<1x2048xf32>
    }
    // Collapse the 2048 partial sums into one.
    %18 = "tt.reduce"(%17) <{axis = 1 : i32}> ({
    ^bb0(%arg6: f32, %arg7: f32):
      %29 = arith.addf %arg6, %arg7 : f32
      tt.reduce.return %29 : f32
    }) : (tensor<1x2048xf32>) -> tensor<1xf32>
    %19 = tt.expand_dims %18 {axis = 1 : i32} : (tensor<1xf32>) -> tensor<1x1xf32>
    // Write the row sum to %arg2[pid].
    %20 = tt.addptr %arg2, %1 : !tt.ptr<f32, 1>, i64
    %21 = tt.splat %20 : (!tt.ptr<f32, 1>) -> tensor<1x1x!tt.ptr<f32, 1>>
    tt.store %21, %19 {cache = 1 : i32, evict = 1 : i32} : tensor<1x1xf32>

    // Pass 3: out = x - max - log(sum), truncated to bf16 and stored at the
    // same offsets in %arg3 (masked to in-bounds lanes).
    %26 = math.log %19 : tensor<1x1xf32>
    %27 = tt.broadcast %26 : (tensor<1x1xf32>) -> tensor<1x2048xf32>
    %28 = tt.splat %arg3 : (!tt.ptr<bf16, 1>) -> tensor<1x2048x!tt.ptr<bf16, 1>>
    scf.for %arg6 = %c0_i32 to %c50257_i32 step %c2048_i32 : i32 {
      %29 = arith.extsi %arg6 : i32 to i64
      %30 = tt.splat %29 : (i64) -> tensor<1x2048xi64>
      %31 = arith.addi %30, %4 : tensor<1x2048xi64>
      %32 = arith.cmpi slt, %31, %cst_1 : tensor<1x2048xi64>
      %33 = arith.addi %31, %6 : tensor<1x2048xi64>
      %34 = tt.addptr %7, %33 : tensor<1x2048x!tt.ptr<bf16, 1>>, tensor<1x2048xi64>
      %35 = tt.load %34, %32, %cst {cache = 1 : i32, evict = 2 : i32, isVolatile = false} : tensor<1x2048xbf16>
      %36 = arith.extf %35 : tensor<1x2048xbf16> to tensor<1x2048xf32>
      %37 = arith.subf %36, %16 : tensor<1x2048xf32>
      %38 = arith.subf %37, %27 : tensor<1x2048xf32>
      %39 = tt.addptr %28, %33 : tensor<1x2048x!tt.ptr<bf16, 1>>, tensor<1x2048xi64>
      %40 = arith.truncf %38 : tensor<1x2048xf32> to tensor<1x2048xbf16>
      tt.store %39, %40, %32 {cache = 1 : i32, evict = 1 : i32} : tensor<1x2048xbf16>
    }
    tt.return
  }
}