Skip to main content

Mojo struct

SM100TensorAccumulatorTS

@register_passable(trivial) struct SM100TensorAccumulatorTS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, num_softmax_threads: Int, swizzle_b: TensorMapSwizzle = TensorMapSwizzle(3), transpose_b: Bool = True, cta_group: Int = 1]

Fields

  • mbar (UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8]):
  • phase (SIMD[uint32, 1]):

Implemented traits

AnyType, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

a_frag_size

alias a_frag_size = (MMA_M * 16) // num_softmax_threads

(`//` is `Int` floor division; it evaluates to 0 when `num_softmax_threads` is 0.)

a_t

alias a_t = TMemOperand[operand_type, (BM * 2) // num_softmax_threads, BN // MMA_N, BM // ((BM * 2) // num_softmax_threads), BK, 16, num_softmax_threads]

(`//` is `Int` floor division; each quotient evaluates to 0 when its divisor is 0.)

ab_t

alias ab_t = UMMADescriptorTS[operand_type, (BM * 2) // num_softmax_threads, BN // MMA_N, MMA_M=BM // ((BM * 2) // num_softmax_threads), MMA_N=BK, MMA_K=16, consumer_group_size=num_softmax_threads]

(`//` is `Int` floor division; each quotient evaluates to 0 when its divisor is 0.)

accum_t

alias accum_t = accum_type

b_offset

alias b_offset = MMAOperandOffsetFn()

b_t

alias b_t = MMASmemDescriptor

c_frag_size

alias c_frag_size = (MMA_M * MMA_N) // num_softmax_threads

(`//` is `Int` floor division; it evaluates to 0 when `num_softmax_threads` is 0.)

c_t

alias c_t = TMemAccumulator[accum_type, 0 if (0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) else (div_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int 
cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, 
cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 
-1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) == 0) ^ True) & ((0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) < 0) ^ (BM < 0))) else div_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), 
lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">), MMA_N, 0 if (num_softmax_threads == 0) else 
(div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BM, "_mlir_value">, 2), #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0 if (MMA_N == 0) else (div_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BN, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BN, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) == 0) ^ True) & ((BN < 0) ^ (MMA_N < 0))) else div_s(#lit.struct.extract<:@stdlib::@builtin::@int::@Int BN, "_mlir_value">, #lit.struct.extract<:@stdlib::@builtin::@int::@Int cond(eq(#lit.struct.extract<:@stdlib::@builtin::@int::@Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">), num_softmax_threads]

idesc

alias idesc = create[::DType,::DType,::DType,::IndexList[::Int()

MMA_K

alias MMA_K = 16

num_k_mmas

alias num_k_mmas = BK // 16

num_m_blocks_per_warp

alias num_m_blocks_per_warp = (BM * 2) // num_softmax_threads (floor division; evaluates to 0 when num_softmax_threads == 0)

num_m_mmas

alias num_m_mmas = BM // MMA_M (floor division; evaluates to 0 when MMA_M == 0)

num_n_mmas

alias num_n_mmas = BN // MMA_N (floor division; evaluates to 0 when MMA_N == 0)

operand_t

alias operand_t = operand_type

smem_ptr_t

alias smem_ptr_t = UnsafePointer[SIMD[operand_type, 1], address_space=AddressSpace(3)]

Methods

__init__

__init__(smem: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8]) -> Self

check_constraints

static check_constraints()

init

init(self)

a_mma_descriptor

static a_mma_descriptor(a_tmem: SIMD[uint32, 1]) -> TMemOperand[operand_type, (BM * 2) // num_softmax_threads, BN // MMA_N, BM // ((BM * 2) // num_softmax_threads), BK, 16, num_softmax_threads]

Here `//` is floor division that evaluates to 0 when the divisor is 0.

Returns:

TMemOperand

b_mma_descriptor

static b_mma_descriptor[dtype_b: DType](p_b: UnsafePointer[SIMD[dtype_b, 1], address_space=AddressSpace(3)]) -> MMASmemDescriptor

Returns:

MMASmemDescriptor

mma

mma(self, a: TMemOperand[operand_type, 0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0 if (MMA_N == 0) else (div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) == 0) ^ True) & ((BN < 0) ^ (MMA_N < 0))) else div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">), 0 if (0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int 
cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) else (div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 
#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, 
num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 
0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) == 0) ^ True) & ((0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) < 0) ^ (BM < 0))) else div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int 
cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">), BK, 16, num_softmax_threads], b: MMASmemDescriptor, c: TMemAccumulator[accum_type, 0 if (0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int 
cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) else (div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 
#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, 
num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 
0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">) == 0) ^ True) & ((0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) < 0) ^ (BM < 0))) else div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int 
cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)})), "_mlir_value">, 0), {1}, cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {0}, cond(and(ne(rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0), xor(lt(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), 0), lt(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0))), {_mlir_value = add(div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), -1)}, {_mlir_value = div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">)}))), "_mlir_value">), MMA_N, 0 if (num_softmax_threads == 0) else (div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), 
"_mlir_value">) + -1) if (((rem_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">) == 0) ^ True) & (((BM * 2) < 0) ^ (num_softmax_threads < 0))) else div_s(mul(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BM, "_mlir_value">, 2), #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int num_softmax_threads, "_mlir_value">, 0), {1}, num_softmax_threads), "_mlir_value">), 0 if (MMA_N == 0) else (div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) + -1) if (((rem_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">) == 0) ^ True) & ((BN < 0) ^ (MMA_N < 0))) else div_s(#lit.struct.extract<:_stdlib::_builtin::_int::_Int BN, "_mlir_value">, #lit.struct.extract<:_stdlib::_builtin::_int::_Int cond(eq(#lit.struct.extract<:_stdlib::_builtin::_int::_Int MMA_N, "_mlir_value">, 0), {1}, MMA_N), "_mlir_value">), num_softmax_threads], c_scale: SIMD[uint32, 1])

wait

wait(mut self, idx: SIMD[uint32, 1])

wait_for_mma

wait_for_mma(mut self)

Wait for the MMA to complete.

wait_for_tmem

wait_for_tmem(mut self)

Wait for the output tmem and the `A` tmem to be ready.

tmem_arrive

tmem_arrive(self)

Indicate that the accumulator and the tensor memory arguments are ready for the MMA to begin.

Was this page helpful?