Mojo function
grouped_matmul_dynamic_scaled_nvfp4
grouped_matmul_dynamic_scaled_nvfp4[c_type: DType, a_type: DType, b_type: DType, scales_type: DType, //, transpose_b: Bool = True, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](c: TileTensor[c_type, c.LayoutType, c.origin, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a: TileTensor[a_type, a.LayoutType, a.origin, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b_type, b.LayoutType, b.origin, linear_idx_type=b.linear_idx_type, element_size=b.element_size], a_scales: TileTensor[scales_type, a_scales.LayoutType, a_scales.origin, linear_idx_type=a_scales.linear_idx_type, element_size=a_scales.element_size], b_scales: TileTensor[scales_type, b_scales.LayoutType, b_scales.origin, linear_idx_type=b_scales.linear_idx_type, element_size=b_scales.element_size], group_offsets: TileTensor[DType.uint32, group_offsets.LayoutType, group_offsets.origin, linear_idx_type=group_offsets.linear_idx_type, element_size=group_offsets.element_size], group_scale_offsets: TileTensor[DType.uint32, group_scale_offsets.LayoutType, group_scale_offsets.origin, linear_idx_type=group_scale_offsets.linear_idx_type, element_size=group_scale_offsets.element_size], expert_ids: TileTensor[DType.int32, expert_ids.LayoutType, expert_ids.origin, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], expert_scales: TileTensor[DType.float32, expert_scales.LayoutType, expert_scales.origin, linear_idx_type=expert_scales.linear_idx_type, element_size=expert_scales.element_size], num_active_experts: Int, ctx: DeviceContext)
Performs grouped matrix multiplication with NVFP4 quantization.
Computes C = A @ B^T for multiple expert groups in a Mixture of Experts (MoE) layer. Inputs A and B are NVFP4 quantized (4-bit floating point), packed as uint8 (2 values per byte), with float8_e4m3fn scale factors. Each group of 16 elements along the K dimension shares a single scale factor (1D block scaling).
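To make the packing and 1D block scaling concrete, here is a hypothetical Python sketch (not part of this API) of NVFP4-style quantization: each 16-element block along K shares one scale, and two 4-bit E2M1 codes are packed into one uint8. The nibble ordering and scale choice here are illustrative assumptions, not the kernel's exact encoding.

```python
# Illustrative NVFP4 (E2M1) block quantization. Assumptions (not from this
# API): low nibble holds the even-indexed element, and the shared scale is
# amax / 6.0 (6.0 being the largest E2M1 magnitude).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_block(block):
    """Quantize 16 floats to 4-bit codes plus one shared scale."""
    assert len(block) == 16
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0
    codes = []
    for v in block:
        mag = abs(v) / scale
        # Nearest representable E2M1 magnitude (codes 0..7).
        code = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - mag))
        if v < 0:
            code |= 0x8  # bit 3 is the sign bit
        codes.append(code)
    return codes, scale

def pack(codes):
    """Pack pairs of 4-bit codes into uint8 (two values per byte)."""
    return [codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2)]

def unpack_dequant(packed, scale):
    """Invert pack() and quantize_block() for inspection."""
    out = []
    for byte in packed:
        for code in (byte & 0xF, byte >> 4):
            v = E2M1_VALUES[code & 0x7] * scale
            out.append(-v if code & 0x8 else v)
    return out
```

In the real kernel, `a_scales` and `b_scales` carry these per-block scales as `float8_e4m3fn` in the tcgen05 layouts described below, rather than as plain floats.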
Accepts TileTensors and converts them to LayoutTensors internally.
Constraints:
- The target device must be SM100 (B200).
Parameters:
- c_type (DType): The data type of the output tensor C.
- a_type (DType): The data type of input tensor A. Constraints: Must be uint8.
- b_type (DType): The data type of input tensor B. Constraints: Must be uint8.
- scales_type (DType): The data type of scale factors. Constraints: Must be float8_e4m3fn.
- transpose_b (Bool): Whether B is transposed. Constraints: Must be True.
- target (StringSlice): The target device.
Args:
- c (TileTensor): The output tensor of shape (total_tokens, N).
- a (TileTensor): The input tensor of shape (total_tokens, K // 2), packed NVFP4.
- b (TileTensor): The weight tensor of shape (num_experts, N, K // 2), packed NVFP4.
- a_scales (TileTensor): The scale factors for A in tcgen05 5D layout.
- b_scales (TileTensor): The scale factors for B in tcgen05 6D layout.
- group_offsets (TileTensor): The starting token index for each expert group.
- group_scale_offsets (TileTensor): The starting scale index for each expert group.
- expert_ids (TileTensor): The expert ID for each group.
- expert_scales (TileTensor): The per-expert scaling factors applied in the epilogue.
- num_active_experts (Int): The number of active experts in this batch.
- ctx (DeviceContext): The device context for GPU execution.
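Since `group_offsets` marks where each expert group's rows begin in the flattened token dimension, a natural way to build it on the host is an exclusive prefix sum over tokens-per-group. The sketch below is a hypothetical illustration; the function name and offset layout (num_groups + 1 entries) are assumptions, not part of this API.

```python
# Hypothetical host-side helper: build group offsets from per-group token
# counts. Group g then covers rows [offsets[g], offsets[g+1]) of A and C.
def build_group_offsets(tokens_per_group):
    offsets = [0]
    for n in tokens_per_group:
        offsets.append(offsets[-1] + n)
    return offsets

# e.g. 3 active expert groups handling 4, 7, and 5 tokens:
# build_group_offsets([4, 7, 5]) -> [0, 4, 11, 16]
```

`group_scale_offsets` would follow the same pattern over scale rows, which differ from token rows because scales are stored in padded tcgen05 layouts.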