
Mojo function

grouped_matmul_dynamic_scaled_nvfp4

grouped_matmul_dynamic_scaled_nvfp4[c_type: DType, a_type: DType, b_type: DType, scales_type: DType, //, transpose_b: Bool = True, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](c: TileTensor[c_type, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a: TileTensor[a_type, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b_type, linear_idx_type=b.linear_idx_type, element_size=b.element_size], a_scales: TileTensor[scales_type, linear_idx_type=a_scales.linear_idx_type, element_size=a_scales.element_size], b_scales: TileTensor[scales_type, linear_idx_type=b_scales.linear_idx_type, element_size=b_scales.element_size], group_offsets: TileTensor[DType.uint32, linear_idx_type=group_offsets.linear_idx_type, element_size=group_offsets.element_size], group_scale_offsets: TileTensor[DType.uint32, linear_idx_type=group_scale_offsets.linear_idx_type, element_size=group_scale_offsets.element_size], expert_ids: TileTensor[DType.int32, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], expert_scales: TileTensor[DType.float32, linear_idx_type=expert_scales.linear_idx_type, element_size=expert_scales.element_size], num_active_experts: Int, ctx: DeviceContext)

Performs grouped matrix multiplication with NVFP4 quantization.

Computes C = A @ B^T for multiple expert groups in a Mixture of Experts (MoE) layer. Inputs A and B hold NVFP4-quantized values (4-bit E2M1 floating point), packed two per uint8 byte, with float8_e4m3fn scale factors. Each block of 16 consecutive elements along the K dimension shares a single scale factor (1D block scaling).
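As a rough illustration of this storage format, the following Python sketch dequantizes one row of NVFP4 data on the host. It assumes the standard E2M1 encoding for the 4-bit values and the standard float8_e4m3fn encoding for scales. It also assumes that the low nibble of each byte holds the earlier element along K and that scales are stored row-major with one byte per 16-element block. The kernel's actual in-memory scale layout may differ. The helpers decode_e2m1, decode_e4m3fn, and dequantize_row are hypothetical and not part of the Mojo API.

```python
import math

# FP4 E2M1 magnitudes indexed by the 3-bit code (2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 value: bit 3 is the sign, bits 0-2 select the magnitude."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def decode_e4m3fn(byte: int) -> float:
    """Decode one float8_e4m3fn value (exponent bias 7, no infinities, S.1111.111 is NaN)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return math.nan
    if exponent == 0:
        # Subnormal: no implicit leading 1, fixed exponent of -6.
        return sign * (mantissa / 8.0) * 2.0 ** -6
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


def dequantize_row(packed: bytes, scales: bytes, block_size: int = 16) -> list:
    """Dequantize one row of K NVFP4 values stored as K/2 packed bytes (hypothetical helper)."""
    k_total = 2 * len(packed)
    if k_total % block_size != 0 or len(scales) != k_total // block_size:
        raise ValueError("expected one scale byte per block of 16 elements")
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0xF))  # element 2*i (assumed low nibble)
        values.append(decode_e2m1(byte >> 4))   # element 2*i + 1 (assumed high nibble)
    # Every element in block j is multiplied by the shared scale scales[j].
    return [v * decode_e4m3fn(scales[k // block_size]) for k, v in enumerate(values)]
```

For example, dequantize_row(bytes([0x72] * 8), bytes([0x40])) decodes each byte 0x72 as the pair (1.0, 6.0) and multiplies it by the shared scale 2.0 (encoded as 0x40), giving [2.0, 12.0] repeated eight times.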

Accepts TileTensors and converts to LayoutTensors internally.

Constraints:

  • The target device must be SM100 (B200).

Parameters:

  • c_type (DType): The data type of the output tensor C.
  • a_type (DType): The data type of input tensor A. Constraints: Must be uint8.
  • b_type (DType): The data type of input tensor B. Constraints: Must be uint8.
  • scales_type (DType): The data type of the scale factors. Constraints: Must be float8_e4m3fn.
  • transpose_b (Bool): Whether B is stored transposed, so that C = A @ B^T. Constraints: Must be True.
  • target (StringSlice[StaticConstantOrigin]): The target device.
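The following Python sketch is a plain-float reference for the math this function computes with transpose_b=True. Each expert's B is stored as N rows of K values, so C[m][n] is the dot product of row m of A with row n of B. It assumes, hypothetically, that group g spans rows group_offsets[g] through group_offsets[g + 1] - 1 of A and uses the weights of expert expert_ids[g]. Quantization, group_scale_offsets, and expert_scales are not modeled. grouped_matmul_reference is an illustrative name, not a Mojo API.

```python
def grouped_matmul_reference(a, b, group_offsets, expert_ids, num_active_experts):
    """Hypothetical reference: C = A @ B^T per expert group on dequantized floats.

    a: M rows of K floats, with rows sorted by group.
    b: b[e] holds expert e's weights as N rows of K floats (transpose_b=True layout).
    group_offsets: num_active_experts + 1 row boundaries into a (assumed meaning).
    expert_ids: expert_ids[g] selects the weights used by group g (assumed meaning).
    """
    n = len(b[0])
    c = [[0.0] * n for _ in a]
    for g in range(num_active_experts):
        weights = b[expert_ids[g]]  # N rows of K values for this group's expert
        for m in range(group_offsets[g], group_offsets[g + 1]):
            for col in range(n):
                # B is stored transposed, so read row `col` of B: C[m][col] = sum_k A[m][k] * B[col][k].
                c[m][col] = sum(x * y for x, y in zip(a[m], weights[col]))
    return c
```

Because B is read row by row, the inner loop walks contiguous K values in both A and B.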

Args: