
Mojo function

shrink_qkv_permute_3mn_sm100

def shrink_qkv_permute_3mn_sm100(c_lora: TileTensor[Storage=c_lora.Storage, linear_idx_type=c_lora.linear_idx_type, element_size=c_lora.element_size], a: TileTensor[Storage=a.Storage, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[Storage=b.Storage, linear_idx_type=b.linear_idx_type, element_size=b.element_size], a_offsets: TileTensor[DType.uint32, Storage=a_offsets.Storage, linear_idx_type=a_offsets.linear_idx_type, element_size=a_offsets.element_size], expert_ids: TileTensor[DType.int32, Storage=expert_ids.Storage, linear_idx_type=expert_ids.linear_idx_type, element_size=expert_ids.element_size], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)

Primary TileTensor-based implementation of shrink_qkv_permute_3mn_sm100.

LoRA shrink grouped matmul (GMM) with planar Q/K/V output on SM100.

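Performs the LoRA "shrink" grouped matmul for routed tokens. For each active expert e, the rows a[a_offsets[e]:a_offsets[e+1], :] are multiplied by b[expert_ids[e]]^T (shape [K, 3N]), so the call as a whole computes [M, K] @ [G, 3N, K]^T. An elementwise epilogue then permutes the flat [M, 3N] result into the planar layout [3, M, N] (Q, K, V), reusing the storage of c_lora instead of a separate output buffer.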
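To make the index bookkeeping concrete, here is a minimal pure-Python reference model. It is a sketch under stated assumptions, not the Mojo kernel or its API: `shrink_qkv_reference` and `planar_offset` are hypothetical names, tensors are nested lists, and rows are assigned to experts through `a_offsets` and `expert_ids` following the conventions listed under Constraints. Each flat coordinate (m, j) of the [M, 3N] product is written straight to slot (j // N, m, j % N) of a row-major [3, M, N] buffer, which is one plausible way an elementwise epilogue can realize the permutation.

```python
def planar_offset(m, j, m_total, n):
    """Linear offset in a row-major [3, M, N] buffer for flat coordinate (m, j)."""
    # Column j of the flat [M, 3N] result belongs to slice j // n (Q, K or V).
    return (j // n) * m_total * n + m * n + (j % n)


def shrink_qkv_reference(a, b, a_offsets, expert_ids, num_active_experts, n):
    """Hypothetical reference: grouped shrink matmul plus planar permute.

    a: M rows of length K; b: G matrices of shape [3N][K].
    Returns the planar [3, M, N] buffer flattened in row-major order.
    """
    m_total = len(a)
    # Assumption: rows not owned by an active expert are left at zero here;
    # the real kernel may simply leave them untouched.
    out = [0.0] * (3 * m_total * n)
    for e in range(num_active_experts):
        g = expert_ids[e]
        if g < 0:  # -1 marks an inactive expert
            continue
        for m in range(a_offsets[e], a_offsets[e + 1]):
            for j in range(3 * n):
                # flat[m, j] = dot(a[m, :], b[g, j, :]), i.e. (a @ b[g]^T)[m, j]
                value = sum(x * w for x, w in zip(a[m], b[g][j]))
                # Elementwise epilogue: store directly into the planar slot.
                out[planar_offset(m, j, m_total, n)] = value
    return out
```

Because it uses plain loops, a model like this is only useful for checking small cases against the kernel's output.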

Constraints:

  • c_lora must be rank 3 with a static first dimension of 3, i.e. shape (3, M, N).
  • a must be rank 2 with shape (M, K), where K matches the last dimension of b.
  • b must be rank 3 with shape (G, 3N, K).
  • The temporary 2D view of c_lora has shape (M, 3N) in row-major order and aliases the storage of c_lora.
  • a_offsets is non-decreasing, with a_offsets[0] == 0 and a_offsets[num_active_experts] == M.
  • expert_ids[i] must lie in [0, G) for valid experts; the kernel may treat -1 as an inactive expert.
  • The epilogue assumes N % vector_width == 0 so that its vector stores are aligned.
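The constraints above can be checked on the host before a launch. The sketch below is hypothetical: the name `check_shrink_qkv_constraints` and its tuple-based shape arguments are assumptions, not part of the Mojo API, and `vector_width` stands in for whatever SIMD width the epilogue actually uses. It encodes each bullet as one Python check and returns N on success.

```python
def check_shrink_qkv_constraints(c_shape, a_shape, b_shape, a_offsets,
                                 expert_ids, num_active_experts, vector_width):
    """Hypothetical host-side check of the documented constraints.

    Shapes are plain tuples. Raises ValueError on the first violation,
    otherwise returns N.
    """
    if len(c_shape) != 3 or c_shape[0] != 3:
        raise ValueError("c_lora must be rank 3 with first dimension 3")
    if len(a_shape) != 2 or len(b_shape) != 3:
        raise ValueError("a must be rank 2 and b must be rank 3")
    m, k = a_shape
    g, three_n, k_b = b_shape
    if k != k_b:
        raise ValueError("a's trailing dimension K must match b's last dimension")
    if three_n % 3 != 0:
        raise ValueError("b's middle dimension must be 3N")
    n = three_n // 3
    if tuple(c_shape) != (3, m, n):
        raise ValueError("c_lora must have shape (3, M, N)")

    # a_offsets: starts at 0, never decreases, and ends at M.
    if len(a_offsets) < num_active_experts + 1:
        raise ValueError("a_offsets needs num_active_experts + 1 entries")
    offsets = a_offsets[: num_active_experts + 1]
    if offsets[0] != 0 or offsets[-1] != m:
        raise ValueError("a_offsets must start at 0 and end at M")
    if any(lo > hi for lo, hi in zip(offsets, offsets[1:])):
        raise ValueError("a_offsets must be non-decreasing")

    # expert_ids: valid ids lie in [0, G); -1 is accepted as "inactive".
    for e in expert_ids[:num_active_experts]:
        if e != -1 and not 0 <= e < g:
            raise ValueError(f"expert id {e} is outside [0, {g})")

    # Epilogue alignment: N must be a multiple of the vector width.
    if n % vector_width != 0:
        raise ValueError("N must be divisible by vector_width")
    return n
```

Running a check like this once per call is cheap compared with the matmul, and it turns silent out-of-bounds reads into an explicit error.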

Args: