Mojo struct
RegisterAccumulatorLayout
@register_passable(trivial)
struct RegisterAccumulatorLayout[MMA_M: Int, MMA_N: Int, num_m_mmas: Int, num_n_mmas: Int, consumer_group_size: Int, *, frag_simdwidth: Int = 2]
Implemented traits
AnyType, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
element_layout
alias element_layout = row_major(1, frag_simdwidth)
frag_size
alias frag_size = (MMA_M * MMA_N) // consumer_group_size  # Int floor division; evaluates to 0 when consumer_group_size == 0
num_row_blocks_per_mma
alias num_row_blocks_per_mma = 2
rows_of_frags_layout
alias rows_of_frags_layout = row_major(num_m_mmas * num_n_mmas, frag_size)  # frag_size = (MMA_M * MMA_N) // consumer_group_size
vec_output_layout
alias vec_output_layout = Layout(
    IntTuple(  # shape
        IntTuple(2, num_m_mmas),
        IntTuple(frag_size // (frag_simdwidth * 2), num_n_mmas),
    ),
    IntTuple(  # stride
        IntTuple(frag_simdwidth, frag_size),
        IntTuple(frag_simdwidth * 2, frag_size * num_m_mmas),
    ),
)  # frag_size = (MMA_M * MMA_N) // consumer_group_size
Methods
description
static description() -> RegisterAccumulatorDescription
Returns:
RegisterAccumulatorDescription
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!