Mojo struct
Struct_grouped_matmul_block_scaled_mxfp4
struct Struct_grouped_matmul_block_scaled_mxfp4[preshuffled_b: Bool = False]
MOGG wrapper for grouped block-scaled matrix multiplication.
Provides graph compiler integration for block-scaled grouped matmul operations used in Mixture of Experts (MoE) layers on AMD GPUs.
Parameters
- preshuffled_b (Bool): When True, dispatches to mxfp4_grouped_matmul_amd_preb, which expects B in the 5D preshuffled layout produced by Shuffler.preshuffle_b_5d. That layout is typically produced by the model's weight adapter at load time (e.g. Kimi K2.5). When False (the default), dispatches to the dense mxfp4_grouped_matmul_amd kernel, which reads B row-major. The caller is responsible for preparing B in the matching layout.
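The dispatch contract described above can be sketched as follows. This page states three things: preshuffled_b chooses between the two kernels, estimated_total_m is consulted only on the preshuffled path, and 0 selects the persistent kernel. This is a minimal Python sketch of that contract, not the wrapper's Mojo code. The function `select_kernel_path`, the `prefer_direct` callback, and the ":persistent"/":direct" labels are hypothetical. `prefer_direct` stands in for the preb dispatcher's internal heuristic, which is not documented here.

```python
from typing import Callable


def select_kernel_path(
    preshuffled_b: bool,
    estimated_total_m: int,
    prefer_direct: Callable[[int], bool],
) -> str:
    """Sketch of the documented dispatch contract (not the real implementation).

    prefer_direct is a HYPOTHETICAL stand-in for the preb dispatcher's
    internal persistent-vs-direct heuristic, which this page does not specify.
    The returned strings are illustrative labels only.
    """
    if not preshuffled_b:
        # Dense kernel reads B row-major; estimated_total_m is ignored here.
        return "mxfp4_grouped_matmul_amd"
    if estimated_total_m == 0:
        # Passing 0 defaults the preshuffled path to the persistent kernel.
        return "mxfp4_grouped_matmul_amd_preb:persistent"
    if prefer_direct(estimated_total_m):
        return "mxfp4_grouped_matmul_amd_preb:direct"
    return "mxfp4_grouped_matmul_amd_preb:persistent"
```

Modeling the heuristic as an injected callback keeps the sketch honest. It fixes only the behavior the documentation guarantees and leaves the actual cutoff unspecified.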
Implemented traits
AnyType, ImplicitlyDestructible
Methods
execute
static def execute[c_type: DType, //, target: StringSlice[StaticConstantOrigin]](c: ManagedTensorSlice[Output, static_spec=c.static_spec], a: ManagedTensorSlice[Input, static_spec=a.static_spec], b: ManagedTensorSlice[Input, static_spec=b.static_spec], a_scales: ManagedTensorSlice[Input, static_spec=a_scales.static_spec], b_scales: ManagedTensorSlice[Input, static_spec=b_scales.static_spec], expert_start_indices: ManagedTensorSlice[Input, static_spec=expert_start_indices.static_spec], expert_ids: ManagedTensorSlice[Input, static_spec=expert_ids.static_spec], max_num_tokens_per_expert: UInt32, num_active_experts: UInt32, estimated_total_m: UInt32, context: DeviceContext)
Executes grouped block-scaled matrix multiplication.
Computes C = A @ B^T for multiple expert groups, where A and B are block-scaled (e.g. MXFP4: 4-bit floating-point values packed two per uint8 byte, which is why their last dimension is K // 2).
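The MXFP4 format mentioned above can be decoded with the following Python reference. It follows the OCP Microscaling (MX) format: E2M1 elements, one E8M0 shared scale per block of 32 elements, and a decoded value equal to element × 2^(scale − 127). The helper names are hypothetical. The nibble order within a byte (low nibble holds the even-indexed element) is an assumption. The kernel's actual packing order and scale layout may differ.

```python
import math

# E2M1 magnitudes indexed by the low 3 bits of a nibble (OCP MX FP4).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
MX_BLOCK_SIZE = 32  # elements that share one E8M0 scale in MXFP4


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 value; bit 3 is the sign bit."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def decode_e8m0(scale_byte: int) -> float:
    """Decode an E8M0 shared scale: 2**(e - 127); 0xFF encodes NaN."""
    if scale_byte == 0xFF:
        return math.nan
    return 2.0 ** (scale_byte - 127)


def dequantize_mxfp4_row(packed: list[int], scales: list[int]) -> list[float]:
    """Dequantize one packed row of K // 2 bytes into K floats.

    ASSUMPTION: the low nibble of each byte holds the even-indexed element.
    """
    k = 2 * len(packed)
    assert k % MX_BLOCK_SIZE == 0, "K must be a multiple of the block size"
    assert len(scales) == k // MX_BLOCK_SIZE, "expected one scale per 32 elements"
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0x0F))  # even index (assumed)
        values.append(decode_e2m1(byte >> 4))    # odd index (assumed)
    # Each element is multiplied by the scale of the 32-element block it lives in.
    return [v * decode_e8m0(scales[i // MX_BLOCK_SIZE]) for i, v in enumerate(values)]
```

For example, under the assumed nibble order, the byte 0x21 with scale byte 128 (a factor of 2) decodes to [1.0, 2.0].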
Parameters:
- c_type (DType): The output tensor data type.
- target (StringSlice[StaticConstantOrigin]): The target GPU device.
Args:
- c (ManagedTensorSlice[Output, static_spec=c.static_spec]): The output tensor of shape (total_tokens, N).
- a (ManagedTensorSlice[Input, static_spec=a.static_spec]): The input tensor of shape (total_tokens, K // 2).
- b (ManagedTensorSlice[Input, static_spec=b.static_spec]): The weight tensor of shape (num_experts, N, K // 2).
- a_scales (ManagedTensorSlice[Input, static_spec=a_scales.static_spec]): The A scale factors in 2D layout.
- b_scales (ManagedTensorSlice[Input, static_spec=b_scales.static_spec]): The B scale factors in 3D layout.
- expert_start_indices (ManagedTensorSlice[Input, static_spec=expert_start_indices.static_spec]): The starting token index for each expert.
- expert_ids (ManagedTensorSlice[Input, static_spec=expert_ids.static_spec]): The expert ID for each group.
- max_num_tokens_per_expert (UInt32): The maximum token count for any expert.
- num_active_experts (UInt32): The number of active experts.
- estimated_total_m (UInt32): The estimated total number of tokens received by this GPU. The preb dispatcher uses it to choose between the persistent and direct kernel paths. Pass 0 to default to the persistent path. Ignored when preshuffled_b == False.
- context (DeviceContext): The device context pointer.
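The grouping semantics of these arguments can be sketched with a small Python reference. It operates on already-dequantized float matrices, so it isolates the expert routing from the MXFP4 decoding and does not model the a_scales/b_scales layouts. The sketch assumes expert_start_indices holds num_active_experts + 1 row offsets, so that group g covers rows expert_start_indices[g] to expert_start_indices[g + 1]. That convention is an assumption, not something this page states. The helpers `grouped_matmul_reference` and `routing_scalars` are hypothetical.

```python
def grouped_matmul_reference(
    a: list[list[float]],              # (total_tokens, K), already dequantized
    b: list[list[list[float]]],        # (num_experts, N, K), already dequantized
    expert_start_indices: list[int],   # ASSUMED: num_active_experts + 1 row offsets
    expert_ids: list[int],             # expert ID for each active group
) -> list[list[float]]:
    """Compute C[rows of group g] = A[rows] @ B[expert_ids[g]]^T for each group."""
    num_active_experts = len(expert_ids)
    assert len(expert_start_indices) == num_active_experts + 1
    n = len(b[0])
    c = [[0.0] * n for _ in range(len(a))]
    for g in range(num_active_experts):
        start, end = expert_start_indices[g], expert_start_indices[g + 1]
        w = b[expert_ids[g]]  # (N, K) weights of the expert serving this group
        for m in range(start, end):
            for j in range(n):
                # Contracting a row of A with a row of B is what makes it A @ B^T.
                c[m][j] = sum(x * y for x, y in zip(a[m], w[j]))
    return c


def routing_scalars(expert_start_indices: list[int]) -> tuple[int, int]:
    """Derive (max_num_tokens_per_expert, num_active_experts) from the offsets."""
    counts = [hi - lo for lo, hi in zip(expert_start_indices, expert_start_indices[1:])]
    return max(counts, default=0), len(counts)
```

For example, with offsets [0, 2, 3] and expert_ids [1, 0], rows 0 and 1 are multiplied by expert 1's weights, and row 2 by expert 0's. routing_scalars then returns (2, 2) for max_num_tokens_per_expert and num_active_experts.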