Mojo function
grouped_matmul_dynamic_scaled_nvfp4
grouped_matmul_dynamic_scaled_nvfp4[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, scales_type: DType, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, a_scale_offsets_layout: Layout, expert_ids_layout: Layout, expert_scales_layout: Layout, //, transpose_b: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin], b_scales: LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], a_scale_offsets: LayoutTensor[DType.uint32, a_scale_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_scales: LayoutTensor[DType.float32, expert_scales_layout, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)
Performs grouped matrix multiplication with NVFP4 quantization.
Computes C = A @ B^T for multiple expert groups in a Mixture of Experts (MoE) layer. Inputs A and B are NVFP4 quantized (4-bit floating point), packed as uint8 (2 values per byte), with float8_e4m3fn scale factors. Each group of 16 elements along the K dimension shares a single scale factor (1D block scaling).
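The packing and block scaling described above can be illustrated with a minimal host-side sketch in plain Python. It is not the kernel's implementation. The helper names are hypothetical, the nibble order (low nibble holds the earlier element) is an assumption, and the scales are assumed to be already decoded from float8_e4m3fn to ordinary floats.

```python
# E2M1 magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK = 16  # elements along K that share one scale factor


def decode_e2m1(code):
    """Decode one 4-bit E2M1 code to a Python float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def dequantize_row(packed, scales):
    """Dequantize one row of K // 2 packed bytes using K // 16 scales.

    Assumption: the low nibble of each byte holds the earlier element.
    Assumption: `scales` are already decoded from float8_e4m3fn to float.
    """
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0x0F))
        values.append(decode_e2m1(byte >> 4))
    # Every run of 16 consecutive elements shares scales[i // 16].
    return [v * scales[i // BLOCK] for i, v in enumerate(values)]


# K = 32: 16 packed bytes and two scale blocks.
packed_row = bytes([0x21] * 8 + [0xF7] * 8)
row = dequantize_row(packed_row, [2.0, 0.5])
```

Here the first 16 elements alternate 0.5 and 1.0 before scaling and are multiplied by 2.0, while the last 16 alternate 6.0 and -6.0 and are multiplied by 0.5.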
Constraints:
- The target device must be SM100 (B200).
Parameters:
- c_type (DType): The data type of the output tensor C.
- c_layout (Layout): The memory layout of the output tensor C.
- a_type (DType): The data type of input tensor A. Constraints: Must be uint8.
- a_layout (Layout): The memory layout of input tensor A.
- b_type (DType): The data type of input tensor B. Constraints: Must be uint8.
- b_layout (Layout): The memory layout of input tensor B.
- scales_type (DType): The data type of scale factors. Constraints: Must be float8_e4m3fn.
- a_scales_layout (Layout): The memory layout of A's scale factors.
- b_scales_layout (Layout): The memory layout of B's scale factors.
- a_offsets_layout (Layout): The memory layout of the token offset indices.
- a_scale_offsets_layout (Layout): The memory layout of A's scale offset indices.
- expert_ids_layout (Layout): The memory layout of the expert ID tensor.
- expert_scales_layout (Layout): The memory layout of the per-expert scale tensor.
- transpose_b (Bool): Whether B is transposed. Constraints: Must be True.
- target (StringSlice): The target device.
Args:
- c (LayoutTensor): The output tensor of shape (total_tokens, N).
- a (LayoutTensor): The input tensor of shape (total_tokens, K // 2), packed NVFP4.
- b (LayoutTensor): The weight tensor of shape (num_experts, N, K // 2), packed NVFP4.
- a_scales (LayoutTensor): The scale factors for A in tcgen05 5D layout.
- b_scales (LayoutTensor): The scale factors for B in tcgen05 6D layout.
- a_offsets (LayoutTensor): The starting token index for each expert group.
- a_scale_offsets (LayoutTensor): The starting scale index for each expert group.
- expert_ids (LayoutTensor): The expert ID for each group.
- expert_scales (LayoutTensor): The per-expert scaling factors applied in the epilogue.
- num_active_experts (Int): The number of active experts in this batch.
- ctx (DeviceContext): The device context for GPU execution.
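To make the roles of a_offsets, expert_ids, and expert_scales concrete, here is a rough pure-Python reference for the grouped computation, operating on already dequantized floats. It is a sketch under stated assumptions rather than the kernel's actual behavior: the function name is hypothetical, a_offsets is assumed to hold num_active_experts + 1 entries so each group ends where the next begins, and expert_scales is assumed to be indexed by expert ID.

```python
def grouped_matmul_reference(a, b, a_offsets, expert_ids, expert_scales,
                             num_active_experts):
    """Reference C = A @ B^T per expert group, on dequantized floats.

    a: total_tokens rows of length K.
    b: num_experts matrices, each N rows of length K (so B is used transposed).
    Assumption: a_offsets has num_active_experts + 1 entries, so group g
    covers rows a_offsets[g] up to (not including) a_offsets[g + 1].
    Assumption: expert_scales is indexed by expert ID.
    """
    n = len(b[0])
    c = [[0.0] * n for _ in range(len(a))]
    for g in range(num_active_experts):
        expert = expert_ids[g]
        weights = b[expert]
        scale = expert_scales[expert]
        for row in range(a_offsets[g], a_offsets[g + 1]):
            for col in range(n):
                dot = sum(x * w for x, w in zip(a[row], weights[col]))
                # Per-expert scale applied after accumulation (epilogue).
                c[row][col] = dot * scale
    return c


# Three tokens, two experts, K = 2, N = 2.
a = [[1, 2], [3, 4], [5, 6]]
b = [[[1, 0], [0, 1]], [[1, 1], [2, 0]]]
c = grouped_matmul_reference(a, b, a_offsets=[0, 2, 3], expert_ids=[1, 0],
                             expert_scales=[1.0, 0.5], num_active_experts=2)
```

In this toy call the first two tokens are routed to expert 1 and the last token to expert 0, so each slice of rows in C is produced from a different weight matrix.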