
Mojo struct

Struct_grouped_matmul_block_scaled_mxfp4

struct Struct_grouped_matmul_block_scaled_mxfp4[preshuffled_b: Bool = False]

MOGG wrapper for grouped block-scaled matrix multiplication.

Provides graph compiler integration for block-scaled grouped matmul operations used in Mixture of Experts (MoE) layers on AMD GPUs.

Parameters

  • preshuffled_b (Bool): Selects the kernel and the layout of B that it expects. When True, dispatches to mxfp4_grouped_matmul_amd_preb, which expects B in the 5D preshuffled layout produced by Shuffler.preshuffle_b_5d (typically applied by the model's weight adapter at load time, e.g. for Kimi K2.5). When False (default), dispatches to the dense mxfp4_grouped_matmul_amd kernel, which reads B row-major. The caller is responsible for preparing B in the layout that matches this setting.

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static def execute[c_type: DType, //, target: StringSlice[StaticConstantOrigin]](c: ManagedTensorSlice[Output, static_spec=c.static_spec], a: ManagedTensorSlice[Input, static_spec=a.static_spec], b: ManagedTensorSlice[Input, static_spec=b.static_spec], a_scales: ManagedTensorSlice[Input, static_spec=a_scales.static_spec], b_scales: ManagedTensorSlice[Input, static_spec=b_scales.static_spec], expert_start_indices: ManagedTensorSlice[Input, static_spec=expert_start_indices.static_spec], expert_ids: ManagedTensorSlice[Input, static_spec=expert_ids.static_spec], max_num_tokens_per_expert: UInt32, num_active_experts: UInt32, estimated_total_m: UInt32, context: DeviceContext)

Executes grouped block-scaled matrix multiplication.

Computes C = A @ B^T separately for each expert group, where A and B are block-scaled (e.g. MXFP4: 4-bit floating-point values packed into uint8 storage).
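The per-group computation described above can be sketched as a plain Python reference model. This is not the Mojo kernel and does not call any MAX API. The tensor shapes, the meaning of expert_start_indices and expert_ids, and the low-nibble-first packing order are all assumptions made for illustration. The FP4 (E2M1) value table and the power-of-two E8M0 scale follow the OCP Microscaling format.

```python
# Reference model only: plain Python, not the Mojo kernel or its API.
# ASSUMPTIONS (not stated on this page):
#   - A is [total_tokens, K]; B is [num_experts, N, K].
#   - expert_start_indices holds num_active_experts + 1 row offsets into A.
#   - expert_ids[g] selects which slice of B group g multiplies against.
#   - Two FP4 values are packed per byte, low nibble first.

# Magnitudes representable by FP4 E2M1 (sign is the top bit of the nibble).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_fp4(nibble):
    """Decode one 4-bit E2M1 value: 1 sign bit, 3 magnitude bits."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_VALUES[nibble & 0x7]


def dequantize_block(packed_bytes, scale_e8m0):
    """Unpack bytes into floats and apply the shared block scale 2^(e - 127)."""
    scale = 2.0 ** (scale_e8m0 - 127)
    values = []
    for byte in packed_bytes:
        values.append(decode_fp4(byte & 0xF) * scale)  # low nibble (assumed first)
        values.append(decode_fp4(byte >> 4) * scale)   # high nibble
    return values


def grouped_matmul_reference(a, b, expert_start_indices, expert_ids,
                             num_active_experts):
    """Compute C = A @ B^T per expert group on already-dequantized floats."""
    n = len(b[0])
    c = [[0.0] * n for _ in range(len(a))]
    for g in range(num_active_experts):
        start = expert_start_indices[g]
        end = expert_start_indices[g + 1]
        weights = b[expert_ids[g]]  # [N, K] slice for this group's expert
        for row in range(start, end):
            for col in range(n):
                # Dot product of an A row with a B row, i.e. A @ B^T.
                c[row][col] = sum(x * w for x, w in zip(a[row], weights[col]))
    return c


# Three tokens: rows 0-1 are routed to expert 1, row 2 to expert 0.
a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
b = [
    [[1.0, 0.0], [0.0, 1.0]],  # expert 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # expert 1: doubles its input
]
c = grouped_matmul_reference(a, b, [0, 2, 3], [1, 0], 2)
print(c)  # [[2.0, 4.0], [6.0, 8.0], [5.0, 6.0]]

# One packed byte 0x21 with scale exponent 128 (a scale of 2.0).
print(dequantize_block([0x21], 128))  # [1.0, 2.0]
```

The real kernels fuse dequantization into the matrix multiply on the GPU. The sketch keeps the two steps separate only to make the semantics easy to check by hand.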

Parameters:

  • c_type (DType): The output tensor data type.
  • target (StringSlice[StaticConstantOrigin]): The target GPU device.

Args: