Skip to main content

Mojo function

mma

mma[MMA_M: Int, MMA_N: Int, MMA_K: Int, transpose_b: Bool, k_group_size: Int, config: MHAConfig, prefetch_function: fn[Int]() capturing -> None, swizzle: OptionalReg[Swizzle] = None, swap_a_b: Bool = False, num_iters: Int = 1, token_gen: Bool = False](c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mut a_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], a_smem_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], mut b_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], b_smem_iter: LayoutTensorIter[dtype, layout, origin, address_space=AddressSpace(3), alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], num_b_rows: OptionalReg[Int] = None)

Was this page helpful?