IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

MXFP4MoERoutedMatmul

struct MXFP4MoERoutedMatmul[BM: Int = 64, BN: Int = 64, BK_ELEMS: Int = 256, num_warps_m: Int = 2, num_warps_n: Int = 2, topk: Int = 1, INPUT_ROW_MODE: InputRowMode = InputRowMode.TOKEN_ID, enable_swizzle: Bool = True]

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

BK_BYTES

comptime BK_BYTES = (BK_ELEMS // 2)

BK_SCALES

comptime BK_SCALES = (BK_ELEMS // 32)

C_FRAG_SIZE

comptime C_FRAG_SIZE = ((MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_M * MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_N) // WARP_SIZE)

FRAG_W_BYTES

comptime FRAG_W_BYTES = 16

MMA_K_BYTES

comptime MMA_K_BYTES = 64

MMA_M

comptime MMA_M = 16

MMA_N

comptime MMA_N = 16

num_k_tiles_per_BK

comptime num_k_tiles_per_BK = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].BK_BYTES // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_K_BYTES)

num_m_mmas

comptime num_m_mmas = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].WM // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_M)

num_n_mmas

comptime num_n_mmas = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].WN // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_N)

num_scale_packs_per_BK

comptime num_scale_packs_per_BK = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].num_k_tiles_per_BK // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].pack_K)

num_threads

comptime num_threads = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].num_warps * WARP_SIZE)

num_warps

comptime num_warps = (num_warps_m * num_warps_n)

pack_K

comptime pack_K = 2

sort_block_m

comptime sort_block_m = BM

WM

comptime WM = (BM // num_warps_m)

WN

comptime WN = (BN // num_warps_n)

Methods

run

static def run[K_BYTES: Int, K_SCALES: Int, N: Int, N_padded_scale: Int](c: TileTensor[address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a_tt: TileTensor[DType.uint8, address_space=a_tt.address_space, linear_idx_type=a_tt.linear_idx_type, element_size=a_tt.element_size], b_pre_tt: TileTensor[DType.uint8, address_space=b_pre_tt.address_space, linear_idx_type=b_pre_tt.linear_idx_type, element_size=b_pre_tt.element_size], sfa_pre_tt: TileTensor[DType.uint8, address_space=sfa_pre_tt.address_space, linear_idx_type=sfa_pre_tt.linear_idx_type, element_size=sfa_pre_tt.element_size], sfb_pre_tt: TileTensor[DType.uint8, address_space=sfb_pre_tt.address_space, linear_idx_type=sfb_pre_tt.linear_idx_type, element_size=sfb_pre_tt.element_size], sorted_token_ids: TileTensor[DType.uint32, address_space=sorted_token_ids.address_space, linear_idx_type=sorted_token_ids.linear_idx_type, element_size=sorted_token_ids.element_size], expert_ids: TileTensor[DType.int32, address_space=expert_ids.address_space, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], num_tokens: Int, size_expert_ids: Int)