Mojo module
common
Shared NVIDIA GPU attention primitives used by both SM90 and SM100 kernels.
This module hosts helpers that are not architecture-specific so that neither
sm90/ nor sm100/ has to import from the other. It currently provides:
- elect(): single-lane election via the elect.sync PTX instruction.
comptime values
ImmutTileTensor1D
comptime ImmutTileTensor1D[dtype: DType] = TileTensor[dtype, Layout[*?, *?], ImmutAnyOrigin]
Parameters
- dtype (DType):
KVTMATile
comptime KVTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int] = TMATensorTile[dtype, Int(3), _padded_shape[Int(3), dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]
Parameters
- dtype (DType):
- swizzle_mode (TensorMapSwizzle):
- BN (Int):
- BK (Int):
QTMATile
comptime QTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BM: Int, depth: Int, group: Int, decoding: Bool, fuse_gqa: Bool = False, num_qk_stages: Int = Int(1)] = TMATensorTile[dtype, Int(4) if decoding or fuse_gqa else Int(3), _padded_shape[Int(4) if decoding or fuse_gqa else Int(3), dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, fuse_gqa=fuse_gqa, num_qk_stages=num_qk_stages](), swizzle_mode](), _ragged_shape[Int(4) if decoding or fuse_gqa else Int(3), dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, fuse_gqa=fuse_gqa, num_qk_stages=num_qk_stages](), swizzle_mode]()]
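The rank of the Q tile in the alias above is chosen by the expression Int(4) if decoding or fuse_gqa else Int(3). As a rough illustration, that selection can be modeled in plain Python. This is a sketch only: q_tma_rank is a hypothetical helper that does not exist in the module, and it mirrors just the rank expression, not the shape computed by q_smem_shape.

```python
def q_tma_rank(decoding: bool, fuse_gqa: bool = False) -> int:
    """Hypothetical helper: mirrors the rank expression in the QTMATile alias."""
    # Rank 4 when decoding or when GQA is fused; rank 3 otherwise.
    return 4 if decoding or fuse_gqa else 3


if __name__ == "__main__":
    # Enumerate all four flag combinations.
    for decoding in (False, True):
        for fuse_gqa in (False, True):
            rank = q_tma_rank(decoding, fuse_gqa)
            print(f"decoding={decoding} fuse_gqa={fuse_gqa} rank={rank}")
```

Only the combination decoding=False, fuse_gqa=False yields a rank-3 tile under this expression.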
Parameters
Structs
- MHAPosition: Position of the MHA-kernel. When decoding=False, q_head_stride == q_num_heads. When decoding=True, q_head_stride == 1.
- NonNullPointer:
- NullPointer:
- Pack:
- PositionSummary:
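The stride rule stated for MHAPosition above can be restated as a small plain-Python model. This is an illustrative sketch, not the struct's implementation: q_head_stride here is a hypothetical free function that only encodes the two documented cases.

```python
def q_head_stride(decoding: bool, q_num_heads: int) -> int:
    """Hypothetical helper restating the rule documented for MHAPosition."""
    if q_num_heads <= 0:
        raise ValueError("q_num_heads must be positive")
    # decoding=True  -> stride of 1
    # decoding=False -> stride equal to the number of query heads
    return 1 if decoding else q_num_heads


if __name__ == "__main__":
    # Example head count chosen arbitrarily for illustration.
    print(q_head_stride(decoding=False, q_num_heads=32))
    print(q_head_stride(decoding=True, q_num_heads=32))
```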
Traits
Functions
- elect:
- get_seq_info:
- kv_coord:
- output_reg_to_smem_st_matrix:
- q_coord: Returns the coordinates for a TMA load on the Q matrix. This load can be 3D, 4D, or 5D.
- q_gmem_shape:
- q_smem_shape:
- q_tma: