Skip to main content

Mojo struct

TMemAccumulator

@register_passable(trivial) struct TMemAccumulator[dtype_: DType, MMA_M: Int, MMA_N: Int, num_m_mmas: Int, num_n_mmas: Int, num_softmax_threads: Int]

Fields

  • tmem_addr (SIMD[uint32, 1])

Implemented traits

AccumulatorTile, AnyType, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

dtype

alias dtype = dtype_

element_layout

alias element_layout = row_major(1, 2)

frag_size

alias frag_size = 0 if (num_softmax_threads == 0) else ((MMA_M * MMA_N) // num_softmax_threads)

layout_t

alias layout_t = RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads]

rows_of_frags_layout

alias rows_of_frags_layout = row_major((num_m_mmas * num_n_mmas), frag_size)

vec_output_layout

alias vec_output_layout = Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple((frag_size // 4), num_n_mmas), Tuple()), IntTuple(IntTuple(2, frag_size), IntTuple(4, (frag_size * num_m_mmas)), Tuple()))

Methods

__init__

__init__(tmem_addr: SIMD[uint32, 1]) -> Self

__getitem__

__getitem__(self, i: SIMD[uint32, 1]) -> Self

check_constraints

static check_constraints()

offset

offset[m_mma: Int, n_mma: Int](self) -> SIMD[uint32, 1]

Returns:

SIMD

rows_of_frags

static rows_of_frags(src: LayoutTensor[dtype_, Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple((frag_size // 4), num_n_mmas), Tuple()), IntTuple(IntTuple(2, frag_size), IntTuple(4, (frag_size * num_m_mmas)), Tuple())), MutableAnyOrigin, address_space=AddressSpace(5), element_layout=row_major(1, 2)]) -> LayoutTensor[dtype_, row_major((num_m_mmas * num_n_mmas), frag_size), MutableAnyOrigin, address_space=AddressSpace(5)]

Returns:

LayoutTensor

allocate_register_tile

static allocate_register_tile() -> LayoutTensor[dtype_, Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple(frag_size // 4, num_n_mmas), Tuple()), IntTuple(IntTuple(2, frag_size), IntTuple(4, frag_size * num_m_mmas), Tuple())), MutableAnyOrigin, address_space=AddressSpace(5), element_layout=row_major(1, 2)]

Here frag_size is this struct's alias of the same name: (MMA_M * MMA_N) // num_softmax_threads using floor division, or 0 when num_softmax_threads == 0. The generated page prints this alias (and its floor division by 4) fully expanded into compiler-internal div_s/rem_s/cond expressions; the signature above is the same type written with the alias.

Returns:

LayoutTensor

copy_from

copy_from(self, src: LayoutTensor[dtype_, Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple(frag_size // 4, num_n_mmas), Tuple()), IntTuple(IntTuple(2, frag_size), IntTuple(4, frag_size * num_m_mmas), Tuple())), MutableAnyOrigin, address_space=AddressSpace(5), element_layout=row_major(1, 2)])

Here frag_size is this struct's alias of the same name: (MMA_M * MMA_N) // num_softmax_threads using floor division, or 0 when num_softmax_threads == 0. The generated page prints this alias (and its floor division by 4) fully expanded into compiler-internal div_s/rem_s/cond expressions; the signature above is the same type written with the alias.

copy_to

copy_to(self, dst: LayoutTensor[dtype_, Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple(frag_size // 4, num_n_mmas), Tuple()), IntTuple(IntTuple(2, frag_size), IntTuple(4, frag_size * num_m_mmas), Tuple())), MutableAnyOrigin, address_space=AddressSpace(5), element_layout=row_major(1, 2)])

Here frag_size is this struct's alias of the same name: (MMA_M * MMA_N) // num_softmax_threads using floor division, or 0 when num_softmax_threads == 0. The generated page prints this alias (and its floor division by 4) fully expanded into compiler-internal div_s/rem_s/cond expressions; the signature above is the same type written with the alias.

Was this page helpful?