Mojo struct
TMemOperand
@register_passable(trivial)
struct TMemOperand[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, num_softmax_threads: Int]
Fields
- tmem_addr (UInt32)
Implemented traits
AnyType, ExplicitlyCopyable, ImplicitlyCopyable, Movable, UnknownDestructibility, WriteableMMAOperandDescriptor
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
frag_size
alias frag_size = (MMA_M * MMA_N) // num_softmax_threads  # floor division; evaluates to 0 when num_softmax_threads == 0
reg_layout
alias reg_layout = RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads]
reg_tile_t
alias reg_tile_t = LayoutTensor[dtype, Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple(frag_size // 4, num_n_mmas)), IntTuple(IntTuple(2, frag_size), IntTuple(4, frag_size * num_m_mmas))), MutableAnyOrigin, address_space=AddressSpace(5), element_layout=Layout.row_major(1, 2)]
vec_output_layout
alias vec_output_layout = Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple(frag_size // 4, num_n_mmas)), IntTuple(IntTuple(2, frag_size), IntTuple(4, frag_size * num_m_mmas)))
Methods
__init__
__init__(tmem_addr: UInt32) -> Self
offset
copy_from
copy_from[src_type: DType, src_layout: Layout, src_element_layout: Layout, //](self, src: LayoutTensor[src_type, src_layout, MutableAnyOrigin, address_space=AddressSpace(5), element_layout=src_element_layout])
copy_to
copy_to[dst_type: DType, dst_layout: Layout, dst_element_layout: Layout, //](self, dst: LayoutTensor[dst_type, dst_layout, MutableAnyOrigin, address_space=AddressSpace(5), element_layout=dst_element_layout])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!