Mojo trait
MHAOperand
This serves as the trait to support arguments to our MHA kernel.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type
Indicate the type being used on accelerator devices.
dtype
comptime dtype
page_size
comptime page_size
quantization_enabled
comptime quantization_enabled = False
quantization_granularity
comptime quantization_granularity
scale_dtype
comptime scale_dtype
Required methods
__init__
def __init__(out self: _Self, *, copy: _Self)
Create a new instance of the value by copying an existing one.
Args:
- copy (_Self): The value to copy.
Returns:
_Self
def __init__(out self: _Self, *, deinit move: _Self)
Create a new instance of the value by moving the value of another.
Args:
- move (_Self): The value to move.
Returns:
_Self
block_paged_ptr
def block_paged_ptr[tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]
scales_block_paged_ptr
def scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.scale_dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[_Self.scale_dtype], ImmutAnyOrigin]
load_scale
def load_scale[width: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]
Returns:
SIMD[_Self.scale_dtype, width]
cache_length
def cache_length(self: _Self, batch_idx: Int) -> Int
Returns the length of the cache for a given batch index.
Returns:
Int
max_context_length
def max_context_length(self: _Self) -> UInt32
Returns the maximum cache length across all batch entries.
Returns:
UInt32
num_kv_rows
def num_kv_rows(self: _Self) -> Int
Returns the total number of virtual rows in the KV memory view.
For paged caches this accounts for the paging stride so that TMA descriptors can be sized to cover the entire address space.
Returns:
Int
row_idx
def row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32
Returns the row index when viewing the memory as a matrix.
Returns:
UInt32
get_tma_row
def get_tma_row(self: _Self, encoded_index: Int32) -> Int32
Convert an encoded sparse index to a physical TMA row.
For paged caches the encoded index is
physical_block * page_size + offset and this method returns
physical_block * stride + offset. Non-paged operands return
the encoded index unchanged.
Returns:
Int32
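The index conversion described for get_tma_row can be modeled outside of Mojo. The following is a minimal Python sketch, not the library's implementation: the function names and the explicit page_size and stride arguments are hypothetical stand-ins for values the real operand carries internally, and it assumes offset < page_size so that the encoded index decomposes uniquely.

```python
def tma_row_paged(encoded_index: int, page_size: int, stride: int) -> int:
    """Model of the paged case (hypothetical helper, not the Mojo API).

    Splits encoded_index = physical_block * page_size + offset and
    rebuilds it as physical_block * stride + offset.
    """
    # Assumes 0 <= offset < page_size, so divmod recovers both parts.
    physical_block, offset = divmod(encoded_index, page_size)
    return physical_block * stride + offset


def tma_row_non_paged(encoded_index: int) -> int:
    """Model of the non-paged case: the encoded index is returned unchanged."""
    return encoded_index
```

For example, with page_size = 16 and an assumed stride of 64, block 3 at offset 5 is encoded as 53 and maps to physical row 197.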
create_tma_tile
def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]
Creates a TMA tile for efficient GPU memory transfers. This is useful for k-major MMA operations where we don't need to mask any extra rows.
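Returns:
TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]
create_scale_tma_tile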
def create_scale_tma_tile[BMN: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.scale_dtype, 2, Index[Int, Int](1, BMN)]
Creates a TMA tile for efficient GPU memory transfers. This is useful for m-major MMA operations where we don't need to mask any extra rows.
Returns:
TMATensorTile[_Self.scale_dtype, 2, Index[Int, Int](1, BMN)]
create_ragged_tma_tile
def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BM=BN, BN=BK]
Creates a TMA tile for efficient GPU memory transfers. This is useful for mn-major MMA operations where we need to mask extra rows to avoid adding NaN to the output through the MMA reduction.
Returns:
RaggedTMA3DTile[_Self.dtype, swizzle_mode, BM=BN, BN=BK]
create_rope_tma_tile
def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]
Creates a BF16 TMA tile for the rope portion of the per-tensor rope-aware KV cache.
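Returns:
TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]
create_gather4_tma_tile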
def create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, tma_dtype: DType = _Self.dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]
Creates a 2D TMA gather4 descriptor for this operand.
The descriptor views the data as a flat 2D matrix of
[num_kv_rows, tile_width] and is configured for gather4 operations
that load 4 non-contiguous rows per TMA instruction. The box width
is derived from the swizzle mode; for SWIZZLE_NONE it equals
tile_width.
When tma_dtype differs from Self.dtype, the underlying data
pointer is bitcast to tma_dtype at descriptor creation time.
Parameters:
- tile_width (Int): Number of elements per row to load (box width), in tma_dtype elements.
- tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the global row is wider than the portion to load.
- swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern. Defaults to SWIZZLE_NONE.
- tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
- tma_dtype (DType): The data type used for the TMA descriptor. Defaults to Self.dtype. When different, the pointer is bitcast.
- l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.
Args:
- ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
Returns:
TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]: A TMATensorTile with box width derived from the swizzle mode.
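To make the data movement concrete, here is a rough Python model of what a gather4 load over a flat 2D view does. It is a sketch only: it ignores swizzling, shared memory, and the TMA hardware entirely, the helper names are hypothetical, and the idea that a tile_height-row tile is filled by tile_height / 4 gather4 operations is an assumption drawn from the "multiple of 4" constraint above.

```python
def gather4(flat, tile_stride, tile_width, row_indices):
    """Model one gather4 load: 4 non-contiguous rows, tile_width elements each.

    flat is the global buffer viewed as rows of tile_stride elements;
    only the first tile_width elements of each selected row are loaded.
    """
    if len(row_indices) != 4:
        raise ValueError("gather4 loads exactly 4 rows per operation")
    if tile_width > tile_stride:
        raise ValueError("tile_width cannot exceed the row stride")
    return [
        flat[r * tile_stride : r * tile_stride + tile_width]
        for r in row_indices
    ]


def gather_tile(flat, tile_stride, tile_width, row_indices):
    """Assumed composition: fill a tile with one gather4 per group of 4 rows."""
    if len(row_indices) % 4 != 0:
        raise ValueError("tile_height must be a multiple of 4")
    tile = []
    for i in range(0, len(row_indices), 4):
        tile.extend(gather4(flat, tile_stride, tile_width, row_indices[i : i + 4]))
    return tile


# Six rows of 8 elements; load the first 4 elements of rows 5, 0, 3, 1.
flat = list(range(6 * 8))
tile = gather_tile(flat, tile_stride=8, tile_width=4, row_indices=[5, 0, 3, 1])
```

Setting tile_stride larger than tile_width, as here, corresponds to the case where the global row is wider than the portion to load.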
create_rope_gather4_tma_tile
def create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]
Creates a BF16 gather4 TMA descriptor for the rope portion of the KV cache.
For the per-tensor rope-aware layout each token row is
padded_depth FP8 bytes (content) followed by BF16 rope elements.
This method offsets the base pointer by padded_depth bytes,
reinterprets as BF16, and creates a gather4 TMA descriptor.
Parameters:
- tile_width (Int): Number of BF16 elements per row in global memory.
- padded_depth (Int): Byte offset from row start to the rope data.
- swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern.
- tile_height (Int): Number of rows in the tile. Must be a multiple of 4.
- l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.
Args:
- ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
Returns:
TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]: A BF16 TMATensorTile configured for gather4.
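The row layout described above (padded_depth FP8 bytes followed by BF16 rope elements) can be illustrated with a small Python sketch that performs the same "offset, then reinterpret as BF16" step on a single row of raw bytes. The helper names are hypothetical, and little-endian byte order is an assumption; the sketch says nothing about how the actual descriptor is built.

```python
import struct


def bf16_bits_to_float(bits: int) -> float:
    """Decode a BF16 bit pattern: BF16 is the upper 16 bits of a float32."""
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def rope_view(row: bytes, padded_depth: int, tile_width: int):
    """Skip padded_depth content bytes and read tile_width BF16 elements.

    Assumes little-endian storage and 2 bytes per BF16 element.
    """
    rope_bytes = row[padded_depth : padded_depth + 2 * tile_width]
    if len(rope_bytes) != 2 * tile_width:
        raise ValueError("row is too short for the requested rope width")
    bits = struct.unpack("<%dH" % tile_width, rope_bytes)
    return [bf16_bits_to_float(b) for b in bits]


# One token row: 8 bytes of FP8 content, then two BF16 rope values (1.0, -2.0).
row = bytes(8) + struct.pack("<2H", 0x3F80, 0xC000)
rope = rope_view(row, padded_depth=8, tile_width=2)
```

Because the rope data starts at a byte offset rather than an element offset, padded_depth is applied before the reinterpretation, which matches the order of operations stated in the description.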
scales_raw_ptr
def scales_raw_ptr(self: _Self) -> UnsafePointer[Float32, MutAnyOrigin]
Returns the base pointer to the quantization scales tensor.
Returns a null pointer for operands without quantization support.
Returns:
UnsafePointer[Float32, MutAnyOrigin]
get_type_name
static def get_type_name() -> String
Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.
Returns:
String: The host type's name.
Provided methods
block_paged_tile
def block_paged_tile[layout_t: TensorLayout, //, tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, layout_val: layout_t, head_dim_idx: UInt32 = UInt32(0)) -> TileTensor[_Self.dtype, layout_t, ImmutAnyOrigin]
Wraps block_paged_ptr in a TileTensor with the caller's layout.
Returns:
TileTensor[_Self.dtype, layout_t, ImmutAnyOrigin]
populate
def populate[BN: Int, base_alignment: Int, pair_cta: Bool = False, is_leader: Bool = True](self: _Self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, _Self.page_size, pair_cta, is_leader]
Populate a full PagedRowIndices[BN, ...] for a BN-row tile.
Returns the precomputed physical row indices for the num_pages
sub-tile pages covering the BN-row range starting at
base_kv_row for batch_idx. Both K's TMA (which may cover only
a subset in pair_cta mode) and V's TMA (which covers the full
range) can then consume the result without any lazy LUT lookup.
base_alignment is a comptime promise that
base_kv_row % base_alignment == 0 at runtime, typically
mask.start_column_alignment[...](). The PagedKVCache
override uses it to pick the largest legal SIMD chunk for its
LUT vector load and to skip the intra-page divmod when
base_alignment % page_size == 0.
Default implementation: scalar loop over num_pages calls to
row_idx. Overrides (e.g. PagedKVCache) replace this with a
single SIMD load from the underlying lookup table.
Returns:
PagedRowIndices[BN, _Self.page_size, pair_cta, is_leader]
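The default behavior described above, a scalar loop over num_pages calls to row_idx, can be sketched in Python as follows. This is an illustrative model under stated assumptions, not the Mojo implementation: num_pages is assumed to be ceil(BN / page_size), each sub-tile page is assumed to start page_size rows after the previous one, and make_row_idx with its lookup table is a hypothetical toy stand-in for a paged operand's row_idx.

```python
def make_row_idx(lut, page_size):
    """Build a toy row_idx backed by a per-batch page lookup table.

    lut[batch_idx][logical_page] gives a physical page number
    (hypothetical layout, for illustration only).
    """
    def row_idx(batch_idx, tok_idx):
        logical_page, offset = divmod(tok_idx, page_size)
        return lut[batch_idx][logical_page] * page_size + offset
    return row_idx


def populate(row_idx, batch_idx, base_kv_row, BN, page_size):
    """Scalar-loop model: one row_idx call per sub-tile page in the BN-row range."""
    num_pages = -(-BN // page_size)  # assumed: ceil(BN / page_size)
    return [
        row_idx(batch_idx, base_kv_row + i * page_size)
        for i in range(num_pages)
    ]


# Batch 0 owns logical pages 0..3, stored in physical pages 7, 2, 9, 4.
row_idx = make_row_idx({0: [7, 2, 9, 4]}, page_size=16)
full_tile = populate(row_idx, batch_idx=0, base_kv_row=0, BN=64, page_size=16)
half_tile = populate(row_idx, batch_idx=0, base_kv_row=32, BN=32, page_size=16)
```

In this model, when base_kv_row is a multiple of page_size the intra-page offset is always zero, which is the situation the base_alignment % page_size == 0 shortcut relies on.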
copy
def copy(self: _Self) -> _Self
Explicitly constructs a copy of self. This is a convenience method for Self(copy=self) when the type is inconvenient to write out.
Returns:
_Self: A copy of this value.