Skip to main content

Mojo struct

MXFP4MoERoutedMatmul

struct MXFP4MoERoutedMatmul[BM: Int = 64, BN: Int = 64, BK_ELEMS: Int = 256, num_warps_m: Int = 2, num_warps_n: Int = 2, topk: Int = 1, INPUT_ROW_MODE: InputRowMode = InputRowMode.TOKEN_ID, enable_swizzle: Bool = True]

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

BK_BYTES

comptime BK_BYTES = (BK_ELEMS // 2)

BK_SCALES

comptime BK_SCALES = (BK_ELEMS // 32)

C_FRAG_SIZE

comptime C_FRAG_SIZE = ((MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_M * MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_N) // WARP_SIZE)

FRAG_W_BYTES

comptime FRAG_W_BYTES = 16

MMA_K_BYTES

comptime MMA_K_BYTES = 64

MMA_M

comptime MMA_M = 16

MMA_N

comptime MMA_N = 16

num_k_tiles_per_BK

comptime num_k_tiles_per_BK = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].BK_BYTES // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_K_BYTES)

num_m_mmas

comptime num_m_mmas = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].WM // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_M)

num_n_mmas

comptime num_n_mmas = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].WN // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_N)

num_scale_packs_per_BK

comptime num_scale_packs_per_BK = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].num_k_tiles_per_BK // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].pack_K)

num_threads

comptime num_threads = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].num_warps * WARP_SIZE)

num_warps

comptime num_warps = (num_warps_m * num_warps_n)

pack_K

comptime pack_K = 2

sort_block_m

comptime sort_block_m = BM

WM

comptime WM = (BM // num_warps_m)

WN

comptime WN = (BN // num_warps_n)

Methods

run

static run[K_BYTES: Int, K_SCALES: Int, N: Int, N_padded_scale: Int](c: TileTensor[address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a_tt: TileTensor[DType.uint8, address_space=a_tt.address_space, linear_idx_type=a_tt.linear_idx_type, element_size=a_tt.element_size], b_pre_tt: TileTensor[DType.uint8, address_space=b_pre_tt.address_space, linear_idx_type=b_pre_tt.linear_idx_type, element_size=b_pre_tt.element_size], sfa_pre_tt: TileTensor[DType.uint8, address_space=sfa_pre_tt.address_space, linear_idx_type=sfa_pre_tt.linear_idx_type, element_size=sfa_pre_tt.element_size], sfb_pre_tt: TileTensor[DType.uint8, address_space=sfb_pre_tt.address_space, linear_idx_type=sfb_pre_tt.linear_idx_type, element_size=sfb_pre_tt.element_size], sorted_token_ids: TileTensor[DType.uint32, address_space=sorted_token_ids.address_space, linear_idx_type=sorted_token_ids.linear_idx_type, element_size=sorted_token_ids.element_size], expert_ids: TileTensor[DType.int32, address_space=expert_ids.address_space, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], num_tokens: Int, size_expert_ids: Int)