Mojo function

shrink_qkv_permute_3mn_sm100

shrink_qkv_permute_3mn_sm100[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList](c_lora: NDBuffer[c_type, 3, MutableAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutableAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutableAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutableAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutableAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)

LoRA shrink GMM with planar Q/K/V output on SM100.

Performs the LoRA 'shrink' grouped matmul for routed tokens: for each active expert, multiplies its token slab of A ([m_e, K]) by that expert's weights from b ([3N, K], transposed), producing a flat [M, 3N] result. An elementwise epilogue then permutes this result into the planar layout [3, M, N] (Q, K, V), reusing the same storage.
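The grouped-matmul and permutation semantics can be illustrated with a NumPy reference sketch. This is an illustrative model of the math, not the SM100 kernel: the function name and in-Python looping are assumptions for exposition only.

```python
import numpy as np

def shrink_qkv_permute_ref(a, b, a_offsets, expert_ids, num_active_experts):
    """NumPy sketch of the kernel's semantics: per-expert shrink matmul,
    then a flat [M, 3N] -> planar [3, M, N] Q/K/V permutation."""
    M, K = a.shape
    G, threeN, K2 = b.shape
    assert K == K2 and threeN % 3 == 0
    N = threeN // 3

    c_flat = np.zeros((M, threeN), dtype=a.dtype)
    for i in range(num_active_experts):
        start, end = int(a_offsets[i]), int(a_offsets[i + 1])
        e = int(expert_ids[i])
        if e < 0:  # -1 marks an inactive group
            continue
        # [m_e, K] @ [K, 3N] -> [m_e, 3N] for this expert's token range
        c_flat[start:end] = a[start:end] @ b[e].T

    # Planarize: columns [0, N) -> Q, [N, 2N) -> K, [2N, 3N) -> V
    return np.stack([c_flat[:, p * N:(p + 1) * N] for p in range(3)])
```

The real kernel performs the permutation in the matmul epilogue with aligned vector stores, aliasing the flat and planar views over one allocation rather than materializing both.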

Constraints:

  • c_lora must be rank 3 with a static first dimension equal to 3 (the Q/K/V planes).
  • a must be rank 2 with trailing dimension K that matches b[..., K].
  • b must be rank 3 with shape (G, 3N, K).
  • The temporary 2D view of c_lora is (M, 3N) in row-major order and aliases the same storage as c_lora.
  • a_offsets is non-decreasing with a_offsets[0] == 0 and a_offsets[num_active_experts] == M.
  • expert_ids[i] ∈ [0, G) for valid experts; kernel may treat -1 as inactive.
  • The epilogue assumes N % vector_width == 0 for aligned vector stores.
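The routing-metadata constraints above can be checked on the host before launch. The helper below is a hypothetical validation sketch (not part of this API); it assumes only the invariants listed in the constraints.

```python
import numpy as np

def check_routing(a_offsets, expert_ids, num_active_experts, M, G):
    """Hypothetical host-side check of the routing constraints:
    non-decreasing offsets bracketed by 0 and M, and expert ids
    in [0, G) or -1 for inactive groups."""
    offs = np.asarray(a_offsets)
    assert offs[0] == 0 and offs[num_active_experts] == M
    # Offsets must be non-decreasing over the active range
    assert np.all(np.diff(offs[: num_active_experts + 1].astype(np.int64)) >= 0)
    for e in np.asarray(expert_ids)[:num_active_experts]:
        assert e == -1 or 0 <= e < G
    return True
```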

Args:

  • c_lora (NDBuffer): Output tensor with planar Q/K/V layout, shape (3, M, N). Backed by row-major storage, used both as a 3D view and as a temporary 2D view (M, 3N) during compute.
  • a (NDBuffer): Routed activation matrix, shape (M, K).
  • b (NDBuffer): Shrink weights per expert, shape (G, 3N, K).
  • a_offsets (NDBuffer): Prefix sums of tokens per (active) expert with a leading 0, length (num_experts + 1). Defines each expert's row range [a_offsets[i], a_offsets[i + 1]) in A/C.
  • expert_ids (NDBuffer): Expert indices for the active groups, length ≥ num_active_experts.
  • max_num_tokens_per_expert (Int): Upper bound on tokens for any active expert.
  • num_active_experts (Int): Number of experts participating in this call.
  • ctx (DeviceContext): DeviceContext used for enqueues and synchronization.