
Mojo function

grouped_matmul_dynamic_scaled_nvfp4

grouped_matmul_dynamic_scaled_nvfp4[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, scales_type: DType, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, a_scale_offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, //, transpose_b: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin], b_scales: LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)

Performs grouped matrix multiplication with NVFP4 quantization.

Computes C = A @ B^T for multiple expert groups in a Mixture of Experts (MoE) layer. The rows of A that belong to each group are multiplied by the weights of the expert selected for that group. Inputs A and B are NVFP4 quantized (4-bit E2M1 floating point) and packed as uint8 with 2 values per byte. They use float8_e4m3fn scale factors. Each group of 16 elements along the K dimension shares a single scale factor (1D block scaling).
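The Python sketch below shows how one packed row of A or B maps back to real values. It is not the kernel's implementation, and it rests on several assumptions. The low nibble of each byte is taken to hold the first of the two values. The scale factors are passed as plain Python floats that stand in for already-decoded float8_e4m3fn values. The tcgen05 scale-factor layout is ignored. The helper names `decode_e2m1`, `unpack_row`, and `dequantize_row` are hypothetical.

```python
# Hypothetical reference sketch of NVFP4 unpacking and 1D block scaling.
# E2M1 magnitudes for codes 0..7 (2 exponent bits, 1 mantissa bit); bit 3 is the sign.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code (0..15) to a Python float."""
    magnitude = E2M1_VALUES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def unpack_row(packed: list[int]) -> list[float]:
    """Unpack K // 2 bytes into K values (assumption: low nibble comes first)."""
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0x0F))
        values.append(decode_e2m1(byte >> 4))
    return values


def dequantize_row(packed: list[int], scales: list[float], block_size: int = 16) -> list[float]:
    """Multiply each block of 16 consecutive K elements by its shared scale factor."""
    values = unpack_row(packed)
    assert len(values) == len(scales) * block_size, "one scale per 16 elements"
    return [v * scales[i // block_size] for i, v in enumerate(values)]


# K = 32: 16 packed bytes and 2 block scales.
row = dequantize_row([0x21] * 16, [2.0, 0.5])
```

In this example, each byte 0x21 unpacks to 0.5 (low nibble) followed by 1.0 (high nibble). The first 16 values are then scaled by 2.0 and the last 16 by 0.5.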

Constraints:

  • The target device must be an SM100 GPU (for example, B200).

Parameters:

  • c_type (DType): The data type of the output tensor C.
  • c_layout (Layout): The memory layout of the output tensor C.
  • a_type (DType): The data type of input tensor A. Constraints: Must be uint8.
  • a_layout (Layout): The memory layout of input tensor A.
  • b_type (DType): The data type of input tensor B. Constraints: Must be uint8.
  • b_layout (Layout): The memory layout of input tensor B.
  • scales_type (DType): The data type of scale factors. Constraints: Must be float8_e4m3fn.
  • a_scales_layout (Layout): The memory layout of A's scale factors.
  • b_scales_layout (Layout): The memory layout of B's scale factors.
  • a_offsets_layout (Layout): The memory layout of the token offset indices.
  • a_scale_offsets_layout (Layout): The memory layout of A's scale offset indices.
  • expert_ids_layout (Layout): The memory layout of the expert ID tensor.
  • expert_scales_layout (Layout): The memory layout of the per-expert scale tensor.
  • transpose_b (Bool): Whether B is transposed. Constraints: Must be True.
  • target (StringSlice): The target device.

Args:

  • c (LayoutTensor): The output tensor of shape (total_tokens, N).
  • a (LayoutTensor): The input tensor of shape (total_tokens, K // 2), packed NVFP4.
  • b (LayoutTensor): The weight tensor of shape (num_experts, N, K // 2), packed NVFP4.
  • a_scales (LayoutTensor): The scale factors for A in tcgen05 5D layout.
  • b_scales (LayoutTensor): The scale factors for B in tcgen05 6D layout.
  • a_offsets (LayoutTensor): The starting token index for each expert group.
  • a_scale_offsets (LayoutTensor): The starting scale index for each expert group.
  • expert_ids (LayoutTensor): The expert ID for each group.
  • expert_scales (LayoutTensor): The per-expert scaling factors applied in the epilogue.
  • num_active_experts (Int): The number of active experts in this batch.
  • ctx (DeviceContext): The device context for GPU execution.
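The sketch below illustrates how the grouping arguments fit together. It works on already-dequantized values and is a plain-Python reference, not the GPU kernel. Several parts of it are assumptions that this page does not confirm. `a_offsets` is assumed to hold `num_active_experts + 1` entries, so group `g` covers rows `a_offsets[g]` up to `a_offsets[g + 1]`. `expert_scales` is assumed to be indexed by expert ID. Rows outside every group are left at zero. The function name `grouped_matmul_reference` is hypothetical.

```python
# Hypothetical reference semantics for the grouped MoE matmul (dequantized inputs).
def grouped_matmul_reference(a, b, a_offsets, expert_ids, expert_scales, num_active_experts):
    """a: total_tokens x K; b: num_experts x N x K (transpose_b=True layout)."""
    n = len(b[0])
    c = [[0.0] * n for _ in range(len(a))]  # rows outside every group stay zero here
    for g in range(num_active_experts):
        start, end = a_offsets[g], a_offsets[g + 1]  # assumed prefix-sum offsets
        expert = expert_ids[g]                       # which expert's weights this group uses
        scale = expert_scales[expert]                # assumed: indexed by expert ID
        w = b[expert]                                # N x K, so C row = A row @ w^T
        for row in range(start, end):
            for col in range(n):
                acc = sum(x * y for x, y in zip(a[row], w[col]))
                c[row][col] = scale * acc            # per-expert scale applied in the epilogue
    return c


a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]           # 3 tokens, K = 2
b = [[[1.0, 0.0], [0.0, 1.0]],                     # expert 0, N = 2
     [[1.0, 1.0], [2.0, 0.0]]]                     # expert 1, N = 2
c = grouped_matmul_reference(a, b, a_offsets=[0, 2, 3], expert_ids=[1, 0],
                             expert_scales=[10.0, 0.5], num_active_experts=2)
```

Here tokens 0 and 1 form the first group and are routed to expert 1 with scale 0.5. Token 2 forms the second group and is routed to expert 0 with scale 10.0. Because the groups are contiguous row ranges, each expert's weights are read once per group rather than once per token.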
