
Mojo struct

Struct_grouped_matmul_block_scaled

struct Struct_grouped_matmul_block_scaled

MOGG wrapper for grouped block-scaled matrix multiplication.

Provides graph compiler integration for block-scaled grouped matmul operations used in Mixture of Experts (MoE) layers on SM100 GPUs.
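The MoE grouping can be modeled in a few lines of plain Python. This sketch covers only the math. It is not the Mojo kernel, and it leaves out block scaling. It assumes that `expert_start_indices` holds `num_active_experts + 1` row offsets into A. Under that assumption, rows `expert_start_indices[i]` up to (but not including) `expert_start_indices[i + 1]` are multiplied by the weights of expert `expert_ids[i]`. Those semantics, the nested-list layout, and the helper name `grouped_matmul_reference` are hypothetical.

```python
def grouped_matmul_reference(a, b_per_expert, expert_start_indices, expert_ids, num_active_experts):
    """Reference model of a grouped C = A @ B^T (hypothetical helper, not the MAX API).

    a:            total_m rows, each a list of K floats (tokens sorted by expert)
    b_per_expert: b_per_expert[e] is expert e's N x K weight matrix
    Assumed: rows [expert_start_indices[i], expert_start_indices[i + 1])
    belong to expert expert_ids[i].
    """
    n = len(b_per_expert[0])
    c = [[0.0] * n for _ in a]
    for i in range(num_active_experts):
        start = expert_start_indices[i]
        end = expert_start_indices[i + 1]
        weights = b_per_expert[expert_ids[i]]
        for m in range(start, end):
            for col in range(n):
                # B is stored N x K, so row `col` of B is column `col` of B^T.
                c[m][col] = sum(x * y for x, y in zip(a[m], weights[col]))
    return c


a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
b_per_expert = [[[1.0, 0.0]], [[0.0, 1.0]]]  # two experts, N = 1, K = 2
print(grouped_matmul_reference(a, b_per_expert, [0, 2, 3], [1, 0], 2))
# [[2.0], [4.0], [5.0]]
```

Because each expert's rows are contiguous in this model, every group is an ordinary dense matmul over a slice of A.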

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static def execute[c_type: DType, a_type: DType, b_type: DType, scales_type: DType, //, target: StringSlice[StaticConstantOrigin]](c: ManagedTensorSlice[Output, static_spec=c.static_spec], a: ManagedTensorSlice[Input, static_spec=a.static_spec], b: ManagedTensorSlice[Input, static_spec=b.static_spec], a_scales: ManagedTensorSlice[Input, static_spec=a_scales.static_spec], b_scales: ManagedTensorSlice[Input, static_spec=b_scales.static_spec], expert_start_indices: ManagedTensorSlice[Input, static_spec=expert_start_indices.static_spec], expert_ids: ManagedTensorSlice[Input, static_spec=expert_ids.static_spec], a_scale_offsets: ManagedTensorSlice[Input, static_spec=a_scale_offsets.static_spec], expert_scales: ManagedTensorSlice[Input, static_spec=expert_scales.static_spec], estimated_total_m: UInt32, num_active_experts: UInt32, context: DeviceContext)

Executes grouped block-scaled matrix multiplication.

Computes C = A @ B^T for each expert group, where A and B hold block-scaled data (e.g., NVFP4: 4-bit floating-point values packed two per uint8 byte).
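To show what "block-scaled" means for the product itself, the plain-Python sketch below computes C = A @ B^T. It starts from already-decoded element values plus one scale per block of K elements in each row of A and B. It models the math only, not the Mojo kernel. The default block size of 16 matches NVFP4. The row-major scale layout and the name `block_scaled_matmul_bt` are assumptions, and the kernel's actual scale-factor layout is not modeled.

```python
def block_scaled_matmul_bt(a, a_scales, b, b_scales, block_size=16):
    """Reference model of C = A @ B^T with per-block scale factors (hypothetical helper).

    a: M x K decoded values, a_scales: M x ceil(K / block_size)
    b: N x K decoded values, b_scales: N x ceil(K / block_size)
    """
    m_dim, n_dim, k_dim = len(a), len(b), len(a[0])
    c = [[0.0] * n_dim for _ in range(m_dim)]
    for m in range(m_dim):
        for n in range(n_dim):
            acc = 0.0
            for k0 in range(0, k_dim, block_size):
                blk = k0 // block_size
                # Dot product over one K-block, then apply both blocks' scales once.
                partial = sum(a[m][k] * b[n][k] for k in range(k0, min(k0 + block_size, k_dim)))
                acc += a_scales[m][blk] * b_scales[n][blk] * partial
            c[m][n] = acc
    return c


a = [[1.0] * 32]  # M = 1, K = 32 (two blocks of 16)
b = [[1.0] * 32]  # N = 1
print(block_scaled_matmul_bt(a, [[2.0, 0.5]], b, [[1.0, 4.0]]))
# [[64.0]]
```

Because a scale is constant within its block, it factors out of that block's partial dot product. In this model, each block's scales multiply its partial sum once rather than every element.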

Parameters:

  • c_type (DType): The output tensor data type.
  • a_type (DType): The input A data type. Constraints: Must be uint8.
  • b_type (DType): The input B data type. Constraints: Must be uint8.
  • scales_type (DType): The scale factor data type. Constraints: Must be float8_e4m3fn.
  • target (StringSlice[StaticConstantOrigin]): The target GPU device.
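The dtype constraints above follow from how NVFP4 data is stored. Each uint8 holds two 4-bit E2M1 values, and each block scale is one float8_e4m3fn byte. The plain-Python decoder below is a sketch for checking data by hand. It is not part of the Mojo API. It assumes that the low nibble of each byte holds the earlier element and that blocks are 16 elements long. The helper names are hypothetical.

```python
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble):
    """Decode a 4-bit E2M1 value: 1 sign bit, 2 exponent bits, 1 mantissa bit."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def decode_e4m3fn(byte):
    """Decode float8_e4m3fn: 1 sign, 4 exponent (bias 7), 3 mantissa bits, no infinities."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return float("nan")  # the only NaN bit pattern (for either sign)
    if exponent == 0:
        return sign * (mantissa / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def dequantize_nvfp4_row(packed, scale_bytes, block_size=16):
    """Unpack a uint8 row (two FP4 values per byte) and apply per-block scales.

    Assumption: the low nibble of each byte is the earlier element.
    """
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0x0F))
        values.append(decode_e2m1(byte >> 4))
    return [v * decode_e4m3fn(scale_bytes[i // block_size]) for i, v in enumerate(values)]


print(decode_e4m3fn(0x7E))                          # 448.0, the largest finite value
print(dequantize_nvfp4_row(bytes([0xF7]), [0x40]))  # [12.0, -12.0]
```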

Args: