Mojo struct
MXFP4MoERoutedMatmul
struct MXFP4MoERoutedMatmul[BM: Int = 64, BN: Int = 64, BK_ELEMS: Int = 256, num_warps_m: Int = 2, num_warps_n: Int = 2, topk: Int = 1, INPUT_ROW_MODE: InputRowMode = InputRowMode.TOKEN_ID, enable_swizzle: Bool = True]
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
BK_BYTES
comptime BK_BYTES = (BK_ELEMS // 2)
BK_SCALES
comptime BK_SCALES = (BK_ELEMS // 32)
C_FRAG_SIZE
comptime C_FRAG_SIZE = ((MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_M * MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_N) // WARP_SIZE)
FRAG_W_BYTES
comptime FRAG_W_BYTES = 16
MMA_K_BYTES
comptime MMA_K_BYTES = 64
MMA_M
comptime MMA_M = 16
MMA_N
comptime MMA_N = 16
num_k_tiles_per_BK
comptime num_k_tiles_per_BK = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].BK_BYTES // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_K_BYTES)
num_m_mmas
comptime num_m_mmas = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].WM // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_M)
num_n_mmas
comptime num_n_mmas = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].WN // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].MMA_N)
num_scale_packs_per_BK
comptime num_scale_packs_per_BK = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].num_k_tiles_per_BK // MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].pack_K)
num_threads
comptime num_threads = (MXFP4MoERoutedMatmul[BM, BN, BK_ELEMS, num_warps_m, num_warps_n, topk, INPUT_ROW_MODE, enable_swizzle].num_warps * WARP_SIZE)
num_warps
comptime num_warps = (num_warps_m * num_warps_n)
pack_K
comptime pack_K = 2
sort_block_m
comptime sort_block_m = BM
WM
comptime WM = (BM // num_warps_m)
WN
comptime WN = (BN // num_warps_n)
Methods
run
static run[K_BYTES: Int, K_SCALES: Int, N: Int, N_padded_scale: Int](c: TileTensor[address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a_tt: TileTensor[DType.uint8, address_space=a_tt.address_space, linear_idx_type=a_tt.linear_idx_type, element_size=a_tt.element_size], b_pre_tt: TileTensor[DType.uint8, address_space=b_pre_tt.address_space, linear_idx_type=b_pre_tt.linear_idx_type, element_size=b_pre_tt.element_size], sfa_pre_tt: TileTensor[DType.uint8, address_space=sfa_pre_tt.address_space, linear_idx_type=sfa_pre_tt.linear_idx_type, element_size=sfa_pre_tt.element_size], sfb_pre_tt: TileTensor[DType.uint8, address_space=sfb_pre_tt.address_space, linear_idx_type=sfb_pre_tt.linear_idx_type, element_size=sfb_pre_tt.element_size], sorted_token_ids: TileTensor[DType.uint32, address_space=sorted_token_ids.address_space, linear_idx_type=sorted_token_ids.linear_idx_type, element_size=sorted_token_ids.element_size], expert_ids: TileTensor[DType.int32, address_space=expert_ids.address_space, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], num_tokens: Int, size_expert_ids: Int)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!