Mojo trait
MHAOperand
This is the trait that arguments to our MHA (multi-head attention) kernel must implement.
Implemented traits
AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable
comptime members
device_type
comptime device_type
Indicates the type being used on accelerator devices.
dtype
comptime dtype
The data type of the operand's elements.
page_size
comptime page_size
The number of tokens per page for paged KV caches.
quantization_enabled
comptime quantization_enabled = False
Whether quantization is enabled for this operand.
quantization_granularity
comptime quantization_granularity
scale_dtype
comptime scale_dtype
The data type of the quantization scales.
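For orientation, here is a sketch of how a conforming type might declare these members. MyOperand and all concrete values are hypothetical, and the trait's required methods are elided:

```mojo
# Hypothetical operand type; values are illustrative and the required
# methods of MHAOperand are omitted, so this is not a full conformance.
struct MyOperand:
    comptime device_type = Self            # type used on accelerator devices
    comptime dtype = DType.bfloat16        # element dtype of the K/V data
    comptime page_size = 128               # tokens per page in a paged cache
    comptime quantization_enabled = False  # matches the trait's default
    comptime quantization_granularity = 1  # hypothetical value
    comptime scale_dtype = DType.float32   # dtype of the quantization scales
```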
Required methods
__init__
__init__(out self: _Self, *, copy: _Self)
Create a new instance of the value by copying an existing one.
Args:
- copy (_Self): The value to copy.
Returns:
_Self
__init__(out self: _Self, *, deinit take: _Self)
Create a new instance of the value by moving the value of another.
Args:
- take (_Self): The value to move.
Returns:
_Self
block_paged_ptr
block_paged_ptr[tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[_Self.dtype], ImmutAnyOrigin]
scales_block_paged_ptr
scales_block_paged_ptr(self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[_Self.scale_dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[_Self.scale_dtype], ImmutAnyOrigin]
load_scale
load_scale[width: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[_Self.scale_dtype, width]
Returns:
SIMD[_Self.scale_dtype, width]
cache_length
cache_length(self: _Self, batch_idx: Int) -> Int
Returns the length of the cache for a given batch index.
Returns:
Int: The length of the cache for the given batch index.
max_context_length
max_context_length(self: _Self) -> UInt32
Returns the maximum cache length across all batch entries.
Returns:
UInt32: The maximum cache length.
num_kv_rows
num_kv_rows(self: _Self) -> Int
Returns the total number of virtual rows in the KV memory view.
For paged caches this accounts for the paging stride so that TMA descriptors can be sized to cover the entire address space.
Returns:
Int: The total number of virtual rows.
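A minimal sketch of the sizing arithmetic, assuming a paged layout whose blocks sit stride rows apart (all numbers hypothetical):

```mojo
fn main():
    # Hypothetical paged cache: 10 physical blocks, 128 tokens per page,
    # but a paging stride of 160 rows between blocks in the flat view.
    var num_blocks = 10
    var page_size = 128
    var stride = 160
    # The TMA descriptor must span the stride, not just the used tokens:
    print("virtual rows:", num_blocks * stride)    # 1600
    print("token rows:  ", num_blocks * page_size) # 1280
```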
row_idx
row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32
Returns the row index when viewing the memory as a flat 2D matrix.
Returns:
UInt32: The row index.
get_tma_row
get_tma_row(self: _Self, encoded_index: Int32) -> Int32
Convert an encoded sparse index to a physical TMA row.
For paged caches the encoded index is physical_block * page_size + offset and this method returns physical_block * stride + offset. Non-paged operands return the encoded index unchanged.
Returns:
Int32: The physical TMA row.
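A worked example of the conversion with hypothetical numbers:

```mojo
fn main():
    # Hypothetical paged cache: page_size = 128, physical stride = 160.
    var page_size = 128
    var stride = 160
    # Encoded sparse index for physical block 3, offset 17:
    var encoded = 3 * page_size + 17   # 401
    # get_tma_row decodes the block/offset and re-maps onto the stride:
    var block = encoded // page_size   # 3
    var offset = encoded % page_size   # 17
    print(block * stride + offset)     # 497, the physical TMA row
```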
create_tma_tile
create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(BN, 1, BK), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(BN, 1, BK), swizzle_mode]()]
Creates a TMA tile for efficient GPU memory transfers. This is useful for k-major MMA operations where we don't need to mask any extra rows.
Returns:
TMATensorTile[_Self.dtype, 3, _padded_shape[3, _Self.dtype, IndexList(BN, 1, BK), swizzle_mode](), _ragged_shape[3, _Self.dtype, IndexList(BN, 1, BK), swizzle_mode]()]
create_scale_tma_tile
create_scale_tma_tile[BMN: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.scale_dtype, 2, Index[Int, Int](1, BMN)]
Creates a TMA tile for efficient GPU memory transfers. This is useful for m-major MMA operations where we don't need to mask any extra rows.
Returns:
TMATensorTile[_Self.scale_dtype, 2, Index[Int, Int](1, BMN)]
create_ragged_tma_tile
create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[_Self.dtype, swizzle_mode, depth]()](self: _Self, ctx: DeviceContext) -> RaggedTMA3DTile[_Self.dtype, swizzle_mode, BM=BN, BN=BK]
Creates a TMA tile for efficient GPU memory transfers. This is useful for mn-major MMA operations where we need to mask extra rows to avoid adding NaN to the output through the MMA reduction.
Returns:
RaggedTMA3DTile[_Self.dtype, swizzle_mode, BM=BN, BN=BK]
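A tiny illustration of why the masking matters (a hypothetical scalar stand-in for the row reduction inside an MMA):

```mojo
fn main():
    # Hypothetical 3-row reduction where row 2 is padding garbage.
    var garbage = Float32(0.0) / Float32(0.0)  # NaN
    var unmasked = Float32(1.0) + Float32(2.0) + garbage
    var masked = Float32(1.0) + Float32(2.0) + Float32(0.0)  # padding masked to 0
    print(unmasked)  # nan: the garbage poisons the whole reduction
    print(masked)    # 3.0
```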
create_rope_tma_tile
create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK), swizzle_mode]()]
Creates a BF16 TMA tile for the rope portion of the per-tensor rope-aware KV cache.
Returns:
TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK), swizzle_mode]()]
create_gather4_tma_tile
create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, tma_dtype: DType = _Self.dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode]()), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode]())]
Creates a 2D TMA gather4 descriptor for this operand.
The descriptor views the data as a flat 2D matrix of
[num_kv_rows, tile_width] and is configured for gather4 operations
that load 4 non-contiguous rows per TMA instruction. The box width
is derived from the swizzle mode; for SWIZZLE_NONE it equals
tile_width.
When tma_dtype differs from Self.dtype, the underlying data
pointer is bitcast to tma_dtype at descriptor creation time.
Parameters:
- tile_width (Int): Number of elements per row to load (box width), in tma_dtype elements.
- tile_stride (Int): Row stride in elements in global memory. Defaults to tile_width. Use a larger value when the global row is wider than the portion to load.
- swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern. Defaults to SWIZZLE_NONE.
- tile_height (Int): Number of rows in the tile. Must be a multiple of 4. Defaults to 4 for backward compatibility.
- tma_dtype (DType): The data type used for the TMA descriptor. Defaults to Self.dtype. When different, the pointer is bitcast.
- l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.
Args:
- ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
Returns:
TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode]()), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode]())]: A TMATensorTile with box width derived from the swizzle mode.
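As a rough, hedged sketch of the geometry these parameters describe (all numbers hypothetical; SWIZZLE_NONE assumed, so the box width equals tile_width):

```mojo
fn main():
    # Hypothetical gather4 descriptor geometry over the flat 2D view.
    var tile_width = 64    # elements loaded per row (the box width)
    var tile_stride = 256  # global row is wider than the loaded portion
    var tile_height = 4    # gather4 fetches 4 non-contiguous rows at once
    print("box:", tile_height, "x", tile_width)
    # Element offset where each gathered row begins in global memory:
    var rows = List[Int](12, 97, 340, 511)  # arbitrary non-contiguous rows
    for i in range(len(rows)):
        print("row", rows[i], "starts at element", rows[i] * tile_stride)
```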
create_rope_gather4_tma_tile
create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self: _Self, ctx: DeviceContext) -> TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode]()), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode]())]
Creates a BF16 gather4 TMA descriptor for the rope portion of the KV cache.
For the per-tensor rope-aware layout each token row is
padded_depth FP8 bytes (content) followed by BF16 rope elements.
This method offsets the base pointer by padded_depth bytes,
reinterprets as BF16, and creates a gather4 TMA descriptor.
Parameters:
- tile_width (Int): Number of BF16 elements per row in global memory.
- padded_depth (Int): Byte offset from row start to the rope data.
- swizzle_mode (TensorMapSwizzle): TMA swizzle mode for shared memory access pattern.
- tile_height (Int): Number of rows in the tile. Must be a multiple of 4.
- l2_promotion (TensorMapL2Promotion): L2 cache promotion hint for TMA loads. Defaults to NONE.
Args:
- ctx (DeviceContext): The CUDA device context used to create the TMA descriptor.
Returns:
TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode]()), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode]())]: A BF16 TMATensorTile configured for gather4.
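A hedged sketch of the row layout this implies, with hypothetical sizes:

```mojo
fn main():
    # Hypothetical sizes for a rope-aware row: padded_depth FP8 bytes of
    # content, then BF16 rope elements.
    var padded_depth = 576              # FP8 content bytes per token row
    var rope_elements = 64              # BF16 rope elements that follow
    var rope_bytes = rope_elements * 2  # BF16 is 2 bytes per element
    # The rope gather4 descriptor offsets the base pointer by
    # padded_depth bytes and reads the BF16 tail of each row:
    print("rope data starts at byte:", padded_depth)
    print("total row bytes:", padded_depth + rope_bytes)
```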
scales_raw_ptr
scales_raw_ptr(self: _Self) -> UnsafePointer[Float32, MutAnyOrigin]
Returns the base pointer to the quantization scales tensor, or a null pointer for operands without quantization support.
Returns:
UnsafePointer[Float32, MutAnyOrigin]: The base pointer to the scales tensor.
get_type_name
static get_type_name() -> String
Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer[DType.float32] would return "DeviceBuffer[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive.
Returns:
String: The host type's name.
Provided methods
block_paged_tile
block_paged_tile[layout_t: TensorLayout, //, tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, layout_val: layout_t, head_dim_idx: UInt32 = UInt32(0)) -> TileTensor[_Self.dtype, layout_t, ImmutAnyOrigin]
Wraps block_paged_ptr in a TileTensor with the caller's layout.
Returns:
TileTensor[_Self.dtype, layout_t, ImmutAnyOrigin]
populate
populate[BN: Int, pair_cta: Bool = False, is_leader: Bool = True](self: _Self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[BN, _Self.page_size, pair_cta, is_leader]
Populate a full PagedRowIndices[BN, ...] for a BN-row tile.
Returns the precomputed physical row indices for the num_pages
sub-tile pages covering the BN-row range starting at
base_kv_row for batch_idx. Both K's TMA (which may cover only
a subset in pair_cta mode) and V's TMA (which covers the full
range) can then consume the result without any lazy LUT lookup.
Default implementation: scalar loop over num_pages calls to
row_idx. Overrides (e.g. PagedKVCache) replace this with a
single SIMD load from the underlying lookup table.
Returns:
PagedRowIndices[BN, _Self.page_size, pair_cta, is_leader]
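A sketch of that documented default. Everything beyond the prose above is an assumption: the PagedRowIndices default constructor, its indices field, and num_pages = BN // page_size are hypothetical, and the body would live inside a conforming struct.

```mojo
# Sketch of the documented default: one scalar row_idx call per
# sub-tile page. PagedRowIndices' constructor, its `indices` field,
# and num_pages = BN // page_size are assumptions, not the real API.
fn populate[
    BN: Int, pair_cta: Bool = False, is_leader: Bool = True
](self, batch_idx: UInt32, base_kv_row: UInt32) -> PagedRowIndices[
    BN, Self.page_size, pair_cta, is_leader
]:
    var result = PagedRowIndices[BN, Self.page_size, pair_cta, is_leader]()
    comptime num_pages = BN // Self.page_size
    for page in range(num_pages):
        # Physical row index for the page starting at this KV row:
        var row = base_kv_row + UInt32(page * Self.page_size)
        result.indices[page] = self.row_idx(batch_idx, row)
    return result
```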
copy
copy(self: _Self) -> _Self
Explicitly constructs a copy of self; a convenience for Self(copy=self) when the type is inconvenient to write out.
Returns:
_Self: A copy of this value.