Mojo function
shrink_qkv_permute_3mn_sm100
shrink_qkv_permute_3mn_sm100[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList](c_lora: NDBuffer[c_type, 3, MutableAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutableAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutableAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutableAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutableAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)
LoRA shrink GMM with planar Q/K/V output on SM100.
Performs the LoRA 'shrink' grouped matmul for routed tokens: computes [M, K] @ [G, 3N, K]^T per active expert, then permutes the flat [M, 3N] result into a planar [3, M, N] (Q, K, V) layout using an elementwise epilogue, while reusing the same storage.
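For reference, the NumPy sketch below mirrors the semantics of this kernel on the host: a per-expert matmul over the rows selected by a_offsets, followed by the flat-to-planar permute. It is an illustration of the math only, not the Mojo API (the function name is made up for the example); the real kernel fuses the permute into its elementwise epilogue and writes into storage aliased with the 2D view rather than allocating a new output.

```python
import numpy as np

def shrink_qkv_permute_reference(a, b, a_offsets, expert_ids, num_active_experts):
    """Host-side sketch of the kernel's math: grouped matmul + QKV permute."""
    M, K = a.shape
    G, three_n, _ = b.shape
    N = three_n // 3
    c2d = np.zeros((M, three_n), dtype=a.dtype)  # flat [M, 3N] intermediate
    for i in range(num_active_experts):
        e = int(expert_ids[i])
        if e < 0:  # a -1 entry marks an inactive group
            continue
        start, end = int(a_offsets[i]), int(a_offsets[i + 1])
        # [rows, K] @ [K, 3N]: the shrink matmul for this expert's tokens
        c2d[start:end] = a[start:end] @ b[e].T
    # Epilogue permute: flat [M, 3N] -> planar [3, M, N] = (Q, K, V)
    return c2d.reshape(M, 3, N).transpose(1, 0, 2).copy()
```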
Constraints:
- c_lora must be rank 3 with static first dimension B == 3.
- a must be rank 2 with trailing dimension K that matches b[..., K].
- b must be rank 3 with shape (G, 3N, K).
- The temporary 2D view of c_lora is (M, 3N) in row-major order and aliases the same storage as c_lora.
- a_offsets is non-decreasing with a_offsets[0] == 0 and a_offsets[num_active_experts] == M (see the sketch after this list).
- expert_ids[i] ∈ [0, G) for valid experts; the kernel may treat -1 as inactive.
- The epilogue assumes N % vector_width == 0 for aligned vector stores.
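As a concrete illustration of the offsets constraint, the snippet below builds a valid a_offsets from hypothetical per-expert token counts (the counts and the three-expert setup are invented for the example). It is the standard CSR-style offset array: a leading 0, a final entry equal to M, and non-decreasing throughout.

```python
import numpy as np

# Hypothetical token counts for three active experts.
tokens_per_expert = np.array([5, 0, 7], dtype=np.uint32)

# CSR-style offsets: a_offsets[i]..a_offsets[i+1] is expert i's
# [start, end) row range; note an expert may own zero rows.
a_offsets = np.concatenate(([0], np.cumsum(tokens_per_expert))).astype(np.uint32)

M = int(tokens_per_expert.sum())
assert a_offsets[0] == 0 and a_offsets[-1] == M
assert np.all(np.diff(a_offsets.astype(np.int64)) >= 0)  # non-decreasing
```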
Args:
- c_lora (NDBuffer): Output tensor with planar Q/K/V layout, shape (3, M, N). Backed by row-major storage, used both as a 3D view and as a temporary 2D view (M, 3N) during compute (see the sketch after this list).
- a (NDBuffer): Routed activation matrix, shape (M, K).
- b (NDBuffer): Shrink weights per expert, shape (G, 3N, K).
- a_offsets (NDBuffer): Prefix-sum offsets of tokens per active expert, length (num_active_experts + 1). a_offsets[i] and a_offsets[i + 1] define the [start, end) row range for group i in A/C.
- expert_ids (NDBuffer): Expert indices for the active groups, length ≥ num_active_experts.
- max_num_tokens_per_expert (Int): Upper bound on the number of tokens routed to any active expert.
- num_active_experts (Int): Number of experts participating in this call.
- ctx (DeviceContext): DeviceContext used for enqueues and synchronization.
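The storage-reuse detail is worth spelling out: a row-major (M, 3N) view and a row-major (3, M, N) view of the same buffer place the same logical element at different flat offsets, so the epilogue's permute is a genuine data movement, not a reinterpretation of the same bytes. A small NumPy sketch (with made-up dimensions) makes the index mismatch visible:

```python
import numpy as np

M, N = 4, 8
storage = np.arange(3 * M * N, dtype=np.float32)  # values == flat indices

view2d = storage.reshape(M, 3 * N)  # (m, q*N + n) lives at m*3N + q*N + n
planar = storage.reshape(3, M, N)   # (q, m, n)    lives at q*M*N + m*N + n

# The same logical (q, m, n) element maps to different flat offsets,
# so permuting in place has to physically move data.
m, q, n = 1, 2, 3
assert view2d[m, q * N + n] != planar[q, m, n]
```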