Mojo function
build_k2q_csr_device
def build_k2q_csr_device[topk: Int](q2k: UnsafePointer[Int32, MutAnyOrigin], cu_seqlens_q: UnsafePointer[Int32, MutAnyOrigin], cu_seqlens_k: UnsafePointer[Int32, MutAnyOrigin], row_ptr: UnsafePointer[Int32, MutAnyOrigin], qsplit_indices: UnsafePointer[Int32, MutAnyOrigin], scheduler_metadata: UnsafePointer[Int32, MutAnyOrigin], work_count: UnsafePointer[Int32, MutAnyOrigin], split_counts: UnsafePointer[Int32, MutAnyOrigin], row_map: UnsafePointer[Int32, MutAnyOrigin], row_coords: UnsafePointer[Int32, MutAnyOrigin], row_counts: UnsafePointer[Int32, MutAnyOrigin], tile_counts: UnsafePointer[Int32, MutAnyOrigin], head_kv: Int, total_q: Int, blk_kv: Int, max_seqlen_q: Int, sizes: K2qCsrDeviceSizes, ctx: DeviceContext, q_per_cta: Int = Int(128))
Builds the reverse CSR (key-to-query) index and the work schedule on the device, writing into caller-provided buffers.
For the same q2k, it emits the same contract tensors as the host-side build_k2q_csr: row_ptr is byte-identical,
qsplit_indices is q-ascending within each row, and scheduler_metadata, work_count, and split_counts are
produced as on the host.
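To make the reverse-CSR contract concrete, here is a minimal host-side Python sketch. It is not the library's implementation. It assumes a hypothetical layout: q2k is a flat array of total_q * topk key-block indices (row-major per query) with negative entries as padding, each query's topk entries are distinct, and head_kv plus the cu_seqlens_q/cu_seqlens_k batching are ignored. The helper name build_reverse_csr and its num_rows argument are hypothetical.

```python
def build_reverse_csr(q2k, total_q, topk, num_rows):
    """Host-side sketch: invert a q -> key-block map into a CSR over key blocks.

    Hypothetical layout (assumption, not confirmed by the API docs):
      - q2k holds total_q * topk key-block indices, row-major per query.
      - Negative entries are padding and are skipped.
      - num_rows is the number of key blocks (CSR rows).
    """
    # Pass 1: histogram of how many queries reference each key block.
    row_counts = [0] * num_rows
    for q in range(total_q):
        for j in range(topk):
            k = q2k[q * topk + j]
            if k >= 0:
                row_counts[k] += 1

    # Exclusive prefix sum turns per-row counts into row offsets.
    row_ptr = [0] * (num_rows + 1)
    for r in range(num_rows):
        row_ptr[r + 1] = row_ptr[r] + row_counts[r]

    # Pass 2: scatter query indices into their rows. Visiting q in
    # ascending order keeps each row's query list q-ascending.
    cursor = row_ptr[:num_rows]  # slicing copies, so row_ptr is untouched
    qsplit_indices = [-1] * row_ptr[num_rows]
    for q in range(total_q):
        for j in range(topk):
            k = q2k[q * topk + j]
            if k >= 0:
                qsplit_indices[cursor[k]] = q
                cursor[k] += 1
    return row_ptr, qsplit_indices


# Three queries, topk = 2, three key blocks; -1 is padding.
row_ptr, qsplit_indices = build_reverse_csr([0, 2, 1, 2, 0, -1], total_q=3, topk=2, num_rows=3)
print(row_ptr)         # [0, 2, 3, 5]
print(qsplit_indices)  # [0, 2, 1, 0, 1]
```

The sequential scatter gets the q-ascending order for free. A parallel device scatter driven by atomic counters would not preserve that order by itself, which is presumably why the contract states the ordering explicitly.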
The outputs and scratch buffers must be pre-initialized wherever the host path pre-initializes them:
row_counts and work_count must be zeroed, qsplit_indices must be filled with -1 (marking the unused tail),
and split_counts must be zero-filled. The q_per_cta parameter here is the per-CTA q-chunk cap for the
forward kernel (the scheduler chunking). It is distinct from sizes.q_per_cta, which the histogram and
scatter passes use to tile the q-range across CTAs.
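The initialization requirements and the meaning of a q-chunk cap can be sketched on the host as follows. This is an illustration under assumptions, not the library's code. The helper names init_buffers and count_q_chunks are hypothetical. The buffer lengths (qsplit_capacity, num_rows, work_count_len) are placeholders, because the real capacities come from K2qCsrDeviceSizes, whose fields are not documented on this page. count_q_chunks only shows what capping each row's query list at q_per_cta queries implies; it is not claimed to reproduce split_counts or scheduler_metadata.

```python
def init_buffers(qsplit_capacity, num_rows, work_count_len=1):
    """Mirror the pre-initialization the device builder expects.

    All lengths are hypothetical placeholders; the real capacities come
    from K2qCsrDeviceSizes.
    """
    return {
        "row_counts": [0] * num_rows,              # scratch, pre-zeroed
        "work_count": [0] * work_count_len,        # output, pre-zeroed
        "split_counts": [0] * num_rows,            # output, zero-filled
        "qsplit_indices": [-1] * qsplit_capacity,  # -1 marks the unused tail
    }


def count_q_chunks(row_ptr, q_per_cta=128):
    """Chunks needed per row if each forward CTA takes at most q_per_cta queries.

    Assumption: a row with n queries needs ceil(n / q_per_cta) chunks,
    and an empty row needs none.
    """
    chunks = []
    for r in range(len(row_ptr) - 1):
        n = row_ptr[r + 1] - row_ptr[r]
        chunks.append((n + q_per_cta - 1) // q_per_cta)  # integer ceil division
    return chunks


bufs = init_buffers(qsplit_capacity=4, num_rows=2)
print(bufs["qsplit_indices"])                           # [-1, -1, -1, -1]
print(count_q_chunks([0, 300, 300, 428], q_per_cta=128))  # [3, 0, 1]
```

The default q_per_cta=128 mirrors the default in the signature above. The sizes.q_per_cta tiling used by the histogram and scatter passes is a separate knob and does not appear in this sketch.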