Mojo struct
Struct_grouped_matmul_swiglu_nvfp4
struct Struct_grouped_matmul_swiglu_nvfp4
MOGG wrapper for fused grouped NVFP4 matmul + SwiGLU + NVFP4 quant.
Fuses the MoE gate/up grouped matmul, SwiGLU activation, and per-block
NVFP4 quantization into a single SM100 kernel. The caller must pre-permute
the weight b and its scale tile b_scales on the N axis with
sigma(2i)=i, sigma(2i+1)=D+i (where D = moe_dim, N = 2D).
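The permutation above can be illustrated with a small host-side sketch. It is not the library's own helper: the function names are hypothetical, and it assumes that the unpermuted weight stacks the D gate rows followed by the D up rows along N, and that sigma maps a permuted row index to its source row (permuted[j] = original[sigma(j)]). The page does not state the direction explicitly, so check it against your weight layout.

```python
def sigma(j: int, d: int) -> int:
    """Source row for permuted row j: sigma(2i) = i, sigma(2i+1) = d + i."""
    i, is_odd = divmod(j, 2)
    return d + i if is_odd else i


def permute_n_axis(rows: list, d: int) -> list:
    """Reorder the 2*d rows along N so that permuted[j] = rows[sigma(j)].

    Assumption (not stated in the docs): rows[0:d] are gate rows and
    rows[d:2*d] are up rows before the permutation.
    """
    if len(rows) != 2 * d:
        raise ValueError("expected N = 2 * D rows")
    return [rows[sigma(j, d)] for j in range(2 * d)]


# Toy example with D = 3: rows are labelled instead of holding real weights.
labels = ["gate0", "gate1", "gate2", "up0", "up1", "up2"]
interleaved = permute_n_axis(labels, d=3)
print(interleaved)
```

Under these assumptions the same reordering would be applied to the N axis of b_scales so that each scale stays with its weight row; the actual scale tile is in a tcgen05 layout, which this sketch does not model.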
Implemented traits
AnyType,
ImplicitlyDestructible
Methods
execute
static def execute[a_type: DType, b_type: DType, scales_type: DType, //, target: StringSlice[StaticConstantOrigin]](c_packed: ManagedTensorSlice[Output, static_spec=c_packed.static_spec], c_swiglu_scales: ManagedTensorSlice[Output, static_spec=c_swiglu_scales.static_spec], a: ManagedTensorSlice[Input, static_spec=a.static_spec], b: ManagedTensorSlice[Input, static_spec=b.static_spec], a_scales: ManagedTensorSlice[Input, static_spec=a_scales.static_spec], b_scales: ManagedTensorSlice[Input, static_spec=b_scales.static_spec], expert_start_indices: ManagedTensorSlice[Input, static_spec=expert_start_indices.static_spec], expert_ids: ManagedTensorSlice[Input, static_spec=expert_ids.static_spec], a_scale_offsets: ManagedTensorSlice[Input, static_spec=a_scale_offsets.static_spec], expert_scales: ManagedTensorSlice[Input, static_spec=expert_scales.static_spec], c_input_scales: ManagedTensorSlice[Input, static_spec=c_input_scales.static_spec], estimated_total_m: UInt32, num_active_experts: UInt32, context: DeviceContext)
Executes fused grouped NVFP4 matmul + SwiGLU + NVFP4 quant.
Computes (c_packed, c_swiglu_scales) = quantize_nvfp4(silu(C[..., even]) * C[..., odd], c_input_scales)
where C = A @ B^T for multiple expert groups. Because B is
sigma-permuted on N, adjacent matmul-output columns carry
(gate, up) pairs that the epilogue consumes in-place.
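As a rough illustration of how the epilogue consumes the interleaved columns, the following floating-point sketch applies silu(C[..., even]) * C[..., odd] to one row of matmul output. It deliberately leaves out the NVFP4 quantization step and the c_input_scales scaling, and the function names are hypothetical rather than part of the kernel's API.

```python
import math


def silu(x: float) -> float:
    """SiLU (swish): x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def swiglu_interleaved(c_row: list) -> list:
    """Compute silu(C[even]) * C[odd] for one output row with 2*D columns.

    Even columns are treated as gate values and odd columns as up values,
    matching the (gate, up) pairing described above. Quantization is omitted.
    """
    if len(c_row) % 2 != 0:
        raise ValueError("expected an even number of columns (N = 2 * D)")
    return [silu(c_row[j]) * c_row[j + 1] for j in range(0, len(c_row), 2)]


# One row with D = 2: columns are (gate0, up0, gate1, up1).
c_row = [0.0, 5.0, 1.0, 2.0]
out = swiglu_interleaved(c_row)
print(out)
```

Each pair of input columns yields one output value, so a row of 2D matmul outputs shrinks to D activations before quantization.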
Parameters:
- a_type (DType): The input A data type. Constraints: Must be uint8.
- b_type (DType): The input B data type. Constraints: Must be uint8.
- scales_type (DType): The scale factor data type. Constraints: Must be float8_e4m3fn.
- target (StringSlice[StaticConstantOrigin]): The target GPU device.
Args:
- c_packed (ManagedTensorSlice[Output, static_spec=c_packed.static_spec]): Packed NVFP4 output of shape (total_tokens, D // 2).
- c_swiglu_scales (ManagedTensorSlice[Output, static_spec=c_swiglu_scales.static_spec]): 5D FP8 SF tile in tcgen05 layout for the output.
- a (ManagedTensorSlice[Input, static_spec=a.static_spec]): The input tensor of shape (total_tokens, K // 2).
- b (ManagedTensorSlice[Input, static_spec=b.static_spec]): The sigma-permuted weight of shape (num_experts, 2D, K // 2).
- a_scales (ManagedTensorSlice[Input, static_spec=a_scales.static_spec]): The A scale factors in tcgen05 5D layout.
- b_scales (ManagedTensorSlice[Input, static_spec=b_scales.static_spec]): The sigma-permuted B scale factors in tcgen05 6D layout.
- expert_start_indices (ManagedTensorSlice[Input, static_spec=expert_start_indices.static_spec]): The starting token index for each expert.
- expert_ids (ManagedTensorSlice[Input, static_spec=expert_ids.static_spec]): The expert ID for each group.
- a_scale_offsets (ManagedTensorSlice[Input, static_spec=a_scale_offsets.static_spec]): The starting scale index for each expert.
- expert_scales (ManagedTensorSlice[Input, static_spec=expert_scales.static_spec]): The per-expert scaling factors for the epilogue.
- c_input_scales (ManagedTensorSlice[Input, static_spec=c_input_scales.static_spec]): Per-expert SiLU input scale (= 1/output_inv_scale).
- estimated_total_m (UInt32): The estimated total number of tokens.
- num_active_experts (UInt32): The number of active experts.
- context (DeviceContext): The device context pointer.
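The shapes listed above follow from packing two 4-bit NVFP4 values into each uint8 byte, which halves the K and D extents. The sketch below derives those shapes and builds per-expert starting token indices from token counts. Both helpers are hypothetical, and the second one assumes that tokens are stored contiguously per expert, so each start index is an exclusive prefix sum of the counts. The page does not specify the exact length or contents of expert_start_indices beyond "the starting token index for each expert".

```python
def packed_shapes(total_tokens: int, k: int, d: int, num_experts: int) -> dict:
    """Shapes of the packed uint8 tensors, with two FP4 values per byte."""
    if k % 2 != 0 or d % 2 != 0:
        raise ValueError("K and D must be even to pack two FP4 values per byte")
    return {
        "a": (total_tokens, k // 2),
        "b": (num_experts, 2 * d, k // 2),
        "c_packed": (total_tokens, d // 2),
    }


def start_indices(tokens_per_expert: list) -> list:
    """Exclusive prefix sum of token counts.

    Assumption: tokens for each expert are laid out contiguously in the
    order given, so expert e starts where the previous experts end.
    """
    starts = []
    running = 0
    for count in tokens_per_expert:
        starts.append(running)
        running += count
    return starts


shapes = packed_shapes(total_tokens=8, k=64, d=32, num_experts=4)
starts = start_indices([3, 0, 5])
print(shapes)
print(starts)
```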