
Mojo function

grouped_matmul_dynamic_scaled_nvfp4

grouped_matmul_dynamic_scaled_nvfp4[c_type: DType, a_type: DType, b_type: DType, scales_type: DType, //, transpose_b: Bool = True, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](c: TileTensor[c_type, c.LayoutType, c.origin, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a: TileTensor[a_type, a.LayoutType, a.origin, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b_type, b.LayoutType, b.origin, linear_idx_type=b.linear_idx_type, element_size=b.element_size], a_scales: TileTensor[scales_type, a_scales.LayoutType, a_scales.origin, linear_idx_type=a_scales.linear_idx_type, element_size=a_scales.element_size], b_scales: TileTensor[scales_type, b_scales.LayoutType, b_scales.origin, linear_idx_type=b_scales.linear_idx_type, element_size=b_scales.element_size], group_offsets: TileTensor[DType.uint32, group_offsets.LayoutType, group_offsets.origin, linear_idx_type=group_offsets.linear_idx_type, element_size=group_offsets.element_size], group_scale_offsets: TileTensor[DType.uint32, group_scale_offsets.LayoutType, group_scale_offsets.origin, linear_idx_type=group_scale_offsets.linear_idx_type, element_size=group_scale_offsets.element_size], expert_ids: TileTensor[DType.int32, expert_ids.LayoutType, expert_ids.origin, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], expert_scales: TileTensor[DType.float32, expert_scales.LayoutType, expert_scales.origin, linear_idx_type=expert_scales.linear_idx_type, element_size=expert_scales.element_size], num_active_experts: Int, ctx: DeviceContext)

Performs grouped matrix multiplication on NVFP4-quantized inputs.

Computes C = A @ B^T for multiple expert groups in a Mixture of Experts (MoE) layer. Inputs A and B are NVFP4-quantized (4-bit E2M1 floating point) and packed two values per uint8 byte, with float8_e4m3fn scale factors. Each block of 16 consecutive elements along the K dimension shares a single scale factor (1D block scaling).
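To make the storage format concrete, here is a minimal pure-Python sketch of how one packed row could be dequantized. It is not the kernel's implementation. The following names and details are assumptions made for illustration: the helper names are hypothetical, each byte is assumed to store its first (lower-index) element in the low nibble, and the scales are assumed to arrive as raw float8_e4m3fn bytes in plain row order rather than in the tcgen05 layout.

```python
# Hypothetical reference helpers; not part of the Mojo API.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # indexed by low 3 bits
BLOCK_SIZE = 16  # consecutive K elements sharing one scale factor


def decode_fp4(nibble: int) -> float:
    """Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0-2 select the magnitude."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def decode_e4m3fn(byte: int) -> float:
    """Decode a float8_e4m3fn byte (exponent bias 7, no infinities)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return float("nan")
    if exponent == 0:  # subnormal
        return sign * (mantissa / 8) * 2.0 ** -6
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)


def unpack_row(packed: list[int]) -> list[float]:
    """Unpack K // 2 bytes into K FP4 values (assumed low nibble first)."""
    values = []
    for byte in packed:
        values.append(decode_fp4(byte & 0x0F))
        values.append(decode_fp4(byte >> 4))
    return values


def dequantize_row(packed: list[int], scale_bytes: list[int]) -> list[float]:
    """Multiply each block of BLOCK_SIZE values by its decoded FP8 scale."""
    values = unpack_row(packed)
    assert len(scale_bytes) * BLOCK_SIZE == len(values), "expected K // 16 scales"
    scales = [decode_e4m3fn(s) for s in scale_bytes]
    return [v * scales[i // BLOCK_SIZE] for i, v in enumerate(values)]
```

For a row with K = 16, the storage is 8 packed bytes plus one scale byte. For example, `dequantize_row([0x21] * 8, [0x40])` decodes each byte to (0.5, 1.0) and scales the result by 2.0.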

This function accepts TileTensor arguments and converts them to LayoutTensors internally.

Constraints:

  • The target device must be SM100 (B200).

Parameters:

  • c_type (DType): The data type of the output tensor C.
  • a_type (DType): The data type of input tensor A. Constraints: Must be uint8.
  • b_type (DType): The data type of input tensor B. Constraints: Must be uint8.
  • scales_type (DType): The data type of scale factors. Constraints: Must be float8_e4m3fn.
  • transpose_b (Bool): Whether B is stored transposed, that is, as (N, K // 2) per expert. Constraints: Must be True.
  • target (StringSlice): The target device.
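The compile-time constraints above can be summarized as a simple check. The sketch below is hypothetical: it represents dtypes as plain strings and uses an invented `NVFP4MatmulParams` record. It is not a Mojo API. It mirrors the documented constraints only and returns the list of violations it finds.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class NVFP4MatmulParams:
    """Hypothetical host-side mirror of the documented compile-time parameters."""
    c_type: str
    a_type: str
    b_type: str
    scales_type: str
    transpose_b: bool = True


def validate(params: NVFP4MatmulParams) -> list[str]:
    """Return a description of every documented constraint that is violated."""
    errors = []
    if params.a_type != "uint8":
        errors.append("a_type must be uint8 (packed NVFP4)")
    if params.b_type != "uint8":
        errors.append("b_type must be uint8 (packed NVFP4)")
    if params.scales_type != "float8_e4m3fn":
        errors.append("scales_type must be float8_e4m3fn")
    if not params.transpose_b:
        errors.append("transpose_b must be True")
    return errors
```

c_type is not checked, because the parameters list documents no constraint on it.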

Args:

  • c (TileTensor): The output tensor of shape (total_tokens, N).
  • a (TileTensor): The input tensor of shape (total_tokens, K // 2), packed NVFP4.
  • b (TileTensor): The weight tensor of shape (num_experts, N, K // 2), packed NVFP4.
  • a_scales (TileTensor): The scale factors for A, stored in the 5D layout used by tcgen05.
  • b_scales (TileTensor): The scale factors for B, stored in the 6D layout used by tcgen05.
  • group_offsets (TileTensor): The starting token index for each expert group.
  • group_scale_offsets (TileTensor): The starting scale index for each expert group.
  • expert_ids (TileTensor): The expert ID for each group.
  • expert_scales (TileTensor): The per-expert scaling factors applied in the epilogue.
  • num_active_experts (Int): The number of active experts in this batch.
  • ctx (DeviceContext): The device context for GPU execution.
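The next sketch shows how the grouping arguments relate to one another. It is a CPU reference written in pure Python and operates on already-dequantized matrices. It is not the GPU kernel, and `grouped_matmul_reference` is a hypothetical name. The sketch rests on the following assumptions:

  • group_offsets holds num_active_experts + 1 entries, and its last entry equals total_tokens.
  • expert_scales is indexed by expert ID rather than by group index.
  • B is stored as (N, K) per expert, matching transpose_b = True.

```python
def grouped_matmul_reference(a, b, group_offsets, expert_ids, expert_scales,
                             num_active_experts):
    """CPU reference for C = A @ B[expert]^T per group, scaled per expert.

    a: (total_tokens, K) dequantized rows.
    b: (num_experts, N, K) dequantized weights (transposed layout).
    """
    total_tokens = len(a)
    n = len(b[0])
    c = [[0.0] * n for _ in range(total_tokens)]
    for g in range(num_active_experts):
        start, end = group_offsets[g], group_offsets[g + 1]  # token range of group g
        expert = expert_ids[g]
        weights = b[expert]                 # (N, K) for this expert
        scale = expert_scales[expert]       # assumed: one scale per expert ID
        for row in range(start, end):
            for col in range(n):
                # Dot product of a token row with row `col` of B^T's source.
                acc = sum(x * w for x, w in zip(a[row], weights[col]))
                c[row][col] = acc * scale   # epilogue scaling
    return c
```

In this layout, each token row is touched by exactly one group, so the groups can be processed independently. This independence is what makes a single grouped launch possible.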
