Mojo function
multistage_mma_q
multistage_mma_q[
    BM: Int,
    BN: Int,
    BK: Int,
    WM: Int,
    WN: Int,
    num_threads: Int,
    num_pipeline_stages: Int,
    transpose_b: Bool,
    group_size: Int,
    pack_factor: Int,
    c_type: DType,
    c_layout: Layout,
    a_type: DType,
    a_layout: Layout,
    a_smem_layout: Layout,
    b_type: DType,
    b_layout: Layout,
    b_smem_layout: Layout,
    scales_type: DType,
    scales_layout: Layout,
    scales_smem_layout: Layout,
    /,
    *,
    swizzle_a: Bool = True,
    static_num_iters: Int = -1,
    prefetch_init: Bool = True,
    continue_prefetch_b: Bool = False,
    transpose_b_next: Bool = False,
    b_next_gmem_layout: Layout = Layout(),
    b_next_smem_layout: Layout = Layout(),
    next_op_b_iter_alignment: Int = align_of[b_type](),
](
    c: LayoutTensor[c_type, c_layout, c.origin, address_space=AddressSpace.LOCAL, element_layout=c.element_layout, layout_int_type=c.layout_int_type, linear_idx_type=c.linear_idx_type, masked=c.masked, alignment=c.alignment],
    a_iter_arg: LayoutTensorIter[a_iter_arg.dtype, a_layout, a_iter_arg.origin, address_space=a_iter_arg.address_space, alignment=a_iter_arg.alignment, circular=a_iter_arg.circular, axis=a_iter_arg.axis, layout_int_type=a_iter_arg.layout_int_type, linear_idx_type=a_iter_arg.linear_idx_type, masked=a_iter_arg.masked],
    b_iter_arg: LayoutTensorIter[b_type, b_layout, b_iter_arg.origin, address_space=b_iter_arg.address_space, alignment=b_iter_arg.alignment, circular=b_iter_arg.circular, axis=b_iter_arg.axis, layout_int_type=b_iter_arg.layout_int_type, linear_idx_type=b_iter_arg.linear_idx_type, masked=b_iter_arg.masked],
    a_smem_iter_arg: LayoutTensorIter[a_type, a_smem_layout, a_smem_iter_arg.origin, address_space=AddressSpace.SHARED, alignment=a_smem_iter_arg.alignment, circular=a_smem_iter_arg.circular, axis=a_smem_iter_arg.axis, layout_int_type=a_smem_iter_arg.layout_int_type, linear_idx_type=a_smem_iter_arg.linear_idx_type, masked=a_smem_iter_arg.masked],
    mut b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, b_smem_iter.origin, address_space=AddressSpace.SHARED, alignment=b_smem_iter.alignment, circular=b_smem_iter.circular, axis=b_smem_iter.axis, layout_int_type=b_smem_iter.layout_int_type, linear_idx_type=b_smem_iter.linear_idx_type, masked=b_smem_iter.masked],
    scales_smem_iter_arg: LayoutTensorIter[scales_type, scales_smem_layout, scales_smem_iter_arg.origin, address_space=AddressSpace.SHARED, alignment=scales_smem_iter_arg.alignment, circular=scales_smem_iter_arg.circular, axis=scales_smem_iter_arg.axis, layout_int_type=scales_smem_iter_arg.layout_int_type, linear_idx_type=scales_smem_iter_arg.linear_idx_type, masked=scales_smem_iter_arg.masked],
    scales_iter_arg: LayoutTensorIter[scales_type, scales_layout, scales_iter_arg.origin, address_space=scales_iter_arg.address_space, alignment=scales_iter_arg.alignment, circular=scales_iter_arg.circular, axis=scales_iter_arg.axis, layout_int_type=scales_iter_arg.layout_int_type, linear_idx_type=scales_iter_arg.linear_idx_type, masked=scales_iter_arg.masked],
    num_iters: Int,
    /,
    *,
    num_b_rows: Optional[Int] = None,
)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!