Mojo struct
TMemAccumulator
@register_passable(trivial)
struct TMemAccumulator[dtype_: DType, MMA_M: Int, MMA_N: Int, num_m_mmas: Int, num_n_mmas: Int, num_consumer_threads: Int]
Fields
- tmem_addr (SIMD[uint32, 1]):
Implemented traits
AccumulatorTile, AnyType, Copyable, Movable, UnknownDestructibility
Aliases
dtype
alias dtype = dtype_
element_layout
alias element_layout = row_major(1, 2)
frag_size
alias frag_size = (MMA_M * MMA_N) // num_consumer_threads
layout_t
alias layout_t = RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_consumer_threads]
rows_of_frags_layout
alias rows_of_frags_layout = row_major((num_m_mmas * num_n_mmas), frag_size)
tmem_addr_t
alias tmem_addr_t = SIMD[uint32, 1]
vec_output_layout
alias vec_output_layout = Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple(frag_size // 4, num_n_mmas)), IntTuple(IntTuple(2, frag_size), IntTuple(4, (frag_size * num_m_mmas))))
Methods
__init__
__init__(tmem_addr: SIMD[uint32, 1]) -> Self
check_constraints
static check_constraints()
offset
offset[m_mma: Int, n_mma: Int](self) -> SIMD[uint32, 1]
rows_of_frags
static rows_of_frags(src: LayoutTensor[dtype_, vec_output_layout, MutableAnyOrigin, address_space=AddressSpace(5), element_layout=row_major(1, 2)]) -> LayoutTensor[dtype_, rows_of_frags_layout, MutableAnyOrigin, address_space=AddressSpace(5)]
allocate_register_tile
allocate_register_tile(self) -> LayoutTensor[dtype_, vec_output_layout, MutableAnyOrigin, address_space=AddressSpace(5), element_layout=row_major(1, 2)]
copy_from
copy_from(self, src: LayoutTensor[dtype_, __init__[::Origin[::Bool(__init__[::Origin[::Bool(IntTuple(2, num_m_mmas), IntTuple(frag_size // 4, num_n_mmas), Tuple()), __init__[::Origin[::Bool(IntTuple(2, frag_size), IntTuple(4, frag_size * num_m_mmas), Tuple())), MutableAnyOrigin, address_space=AddressSpace(5), element_layout=row_major(1, 2)])
copy_to
copy_to(self, dst: LayoutTensor[dtype_, __init__[::Origin[::Bool(__init__[::Origin[::Bool(IntTuple(2, num_m_mmas), IntTuple(frag_size // 4, num_n_mmas), Tuple()), __init__[::Origin[::Bool(IntTuple(2, frag_size), IntTuple(4, frag_size * num_m_mmas), Tuple())), MutableAnyOrigin, address_space=AddressSpace(5), element_layout=row_major(1, 2)])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!