Mojo struct
SM100TensorAccumulatorTS
@register_passable("trivial")
struct SM100TensorAccumulatorTS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, num_softmax_threads: Int, swizzle_b: TensorMapSwizzle = TensorMapSwizzle(3), transpose_b: Bool = True, cta_group: Int = 1]
Fields
- mbar (`UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8]`)
- phase (`SIMD[uint32, 1]`)
Implemented traits
AnyType, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
a_frag_size
alias a_frag_size = MMA_M * 16 // num_softmax_threads
(`//` is `Int` floor division; the expression evaluates to 0 when `num_softmax_threads` is 0.)
a_t
alias a_t = TMemOperand[operand_type, BM * 2 // num_softmax_threads, BN // MMA_N, BM // (BM * 2 // num_softmax_threads), BK, 16, num_softmax_threads]
ab_t
alias ab_t = UMMADescriptorTS[operand_type, BM * 2 // num_softmax_threads, BN // MMA_N, MMA_M=BM // (BM * 2 // num_softmax_threads), MMA_N=BK, MMA_K=16, consumer_group_size=num_softmax_threads]
accum_t
alias accum_t = accum_type
b_offset
alias b_offset = MMAOperandOffsetFn()
b_t
alias b_t = MMASmemDescriptor
c_frag_size
alias c_frag_size = MMA_M * MMA_N // num_softmax_threads
(`//` is `Int` floor division; the expression evaluates to 0 when `num_softmax_threads` is 0.)
c_t
alias c_t = TMemAccumulator[accum_type, 0 if (0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) else (div_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int 
cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, 
cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 
-1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) == 0) ^ True) & ((0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) < 0) ^ (BM < 0))) else div_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), 
lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">), MMA_N, 0 if (num_softmax_threads == 0) else 
(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0 if (MMA_N == 0) else (div_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BN, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BN, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) == 0) ^ True) & ((BN < 0) ^ (MMA_N < 0))) else div_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BN, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">), num_softmax_threads]
idesc
alias idesc = create[::DType,::DType,::DType,::IndexList[::Int()
MMA_K
alias MMA_K = 16
num_k_mmas
alias num_k_mmas = BK // 16
num_m_blocks_per_warp
alias num_m_blocks_per_warp = (BM * 2) // num_softmax_threads
num_m_mmas
alias num_m_mmas = BM // MMA_M
num_n_mmas
alias num_n_mmas = BN // MMA_N
operand_t
alias operand_t = operand_type
smem_ptr_t
alias smem_ptr_t = UnsafePointer[SIMD[operand_type, 1], address_space=AddressSpace(3)]
Methods
__init__
__init__(smem: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8]) -> Self
check_constraints
static check_constraints()
init
init(self)
a_mma_descriptor
static a_mma_descriptor(a_tmem: SIMD[uint32, 1]) -> TMemOperand[operand_type, (BM * 2) // num_softmax_threads, BN // MMA_N, BM // ((BM * 2) // num_softmax_threads), BK, 16, num_softmax_threads]
Returns:
b_mma_descriptor
static b_mma_descriptor[dtype_b: DType](p_b: UnsafePointer[SIMD[dtype_b, 1], address_space=AddressSpace(3)]) -> MMASmemDescriptor
Returns:
mma
mma(self, a: TMemOperand[operand_type, 0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0 if (MMA_N == 0) else (div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) == 0) ^ True) & ((BN < 0) ^ (MMA_N < 0))) else div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">), 0 if (0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int 
cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) else (div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 
#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, 
num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 
0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) == 0) ^ True) & ((0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) < 0) ^ (BM < 0))) else div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int 
cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">), BK, 16, num_softmax_threads], b: MMASmemDescriptor, c: TMemAccumulator[accum_type, 0 if (0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int 
cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) else (div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 
#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, 
num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 
0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) == 0) ^ True) & ((0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) < 0) ^ (BM < 0))) else div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int 
cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">), MMA_N, 0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), 
"_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0 if (MMA_N == 0) else (div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) == 0) ^ True) & ((BN < 0) ^ (MMA_N < 0))) else div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">), num_softmax_threads], c_scale: SIMD[uint32, 1])
wait
wait(mut self, idx: SIMD[uint32, 1])
wait_for_mma
wait_for_mma(mut self)
Wait for the MMA to complete.
wait_for_tmem
wait_for_tmem(mut self)
Wait for the output and A TMEM to be ready.
tmem_arrive
tmem_arrive(self)
Indicate that the accumulator and the tensor memory arguments are ready for the MMA to begin.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!