Mojo struct
LayoutTensorMHAOperand
struct LayoutTensorMHAOperand[origin: ImmutOrigin, scale_origin: ImmutOrigin, //, dtype_: DType, buffer_layout: TensorLayout, scale_dtype_: DType = DType.float32, scale_buffer_layout: TensorLayout = Layout[*?, *?]]
An MHAOperand implementation for contiguous tensor arguments to multi-head attention (MHA) kernels.
Fields
- buffer (TileTensor[LayoutTensorMHAOperand[dtype_, buffer_layout, scale_dtype_, scale_buffer_layout].dtype, buffer_layout, origin]):
- scale_buffer (TileTensor[LayoutTensorMHAOperand[dtype_, buffer_layout, scale_dtype_, scale_buffer_layout].scale_dtype, scale_buffer_layout, scale_origin]):
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDeletable,
MHAOperand,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type = LayoutTensorMHAOperand[dtype_, buffer_layout, scale_dtype_, scale_buffer_layout]
dtype
comptime dtype = dtype_
layout_dim
comptime layout_dim = buffer_layout.__shape_types[SIMDSize((buffer_layout.rank - Int(1)))].static_value
layout_rank
comptime layout_rank = buffer_layout.rank
page_size
comptime page_size = 0
quantization_enabled
comptime quantization_enabled = (scale_buffer_layout.rank != Int(0))
quantization_granularity
comptime quantization_granularity = ceildiv(buffer_layout.__shape_types[(buffer_layout.rank - Int(1))].static_value, scale_buffer_layout.__shape_types[(scale_buffer_layout.rank - Int(1))].static_value if (scale_buffer_layout.rank != Int(0)) else Int(1))
scale_dim
comptime scale_dim = scale_buffer_layout.__shape_types[SIMDSize((scale_buffer_layout.rank - Int(1)))].static_value if (scale_buffer_layout.rank != Int(0)) else Int(1)
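The relationship between layout_dim, quantization_enabled, scale_dim, and quantization_granularity can be modeled with ordinary integers. The following is a minimal sketch in plain Python rather than Mojo, written only to illustrate the arithmetic in the definitions above. The function name `quantization_constants`, the shape tuples, and the use of an empty tuple to stand in for a rank-0 scale layout are assumptions made for illustration and are not part of the library.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(-a // b)


def quantization_constants(buffer_shape: tuple, scale_shape: tuple) -> dict:
    """Hypothetical model of the comptime members for static shapes.

    An empty scale_shape stands in for a rank-0 scale layout.
    """
    # layout_dim: static size of the innermost buffer dimension.
    layout_dim = buffer_shape[-1]
    # quantization_enabled: true when the scale layout has nonzero rank.
    quantization_enabled = len(scale_shape) != 0
    # scale_dim: innermost scale dimension, or 1 when there are no scales.
    scale_dim = scale_shape[-1] if quantization_enabled else 1
    return {
        "layout_dim": layout_dim,
        "quantization_enabled": quantization_enabled,
        "scale_dim": scale_dim,
        # Number of buffer elements covered by each scale value.
        "quantization_granularity": ceildiv(layout_dim, scale_dim),
    }


# Example shapes (assumed): depth 128 with 4 scales per row.
quantized = quantization_constants((2, 1024, 8, 128), (2, 1024, 8, 4))
unquantized = quantization_constants((2, 1024, 8, 128), ())
```

Under these assumed shapes, each scale covers 32 elements in the quantized case, and the unquantized case falls back to a single notional scale covering the whole innermost dimension.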
scale_dtype
comptime scale_dtype = scale_dtype_
scale_rank
comptime scale_rank = scale_buffer_layout.rank
Methods
__init__
def __init__(buffer: TileTensor[Self.dtype, buffer_layout, origin], scale_buffer: TileTensor[Self.scale_dtype, scale_buffer_layout, scale_origin] = _null_scale_tile_tensor[LayoutTensorMHAOperand[dtype_, buffer_layout, scale_dtype_, scale_buffer_layout].scale_dtype, scale_buffer_layout]()) -> Self
get_type_name
block_paged_ptr
def block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]
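For a contiguous operand, the four indices taken by block_paged_ptr can be thought of as selecting one element of a dense buffer. The sketch below is plain Python rather than Mojo. It assumes a row-major buffer of shape (batch, seq_len, num_heads, depth), which this page does not state, and it does not model the tile_size parameter. The helper name `element_offset` is hypothetical.

```python
def element_offset(
    shape: tuple,
    batch_idx: int,
    start_tok_idx: int,
    head_idx: int,
    head_dim_idx: int = 0,
) -> int:
    """Flat element offset into an assumed row-major 4D buffer.

    shape is (batch, seq_len, num_heads, depth); this ordering is an
    assumption for illustration, not a documented property.
    """
    _, seq_len, num_heads, depth = shape
    # Collapse (batch, token) into one row, then step by head and depth.
    row = batch_idx * seq_len + start_tok_idx
    return (row * num_heads + head_idx) * depth + head_dim_idx


shape = (2, 4, 3, 8)  # hypothetical small tensor
first = element_offset(shape, 0, 0, 0)
middle = element_offset(shape, 1, 2, 1, 3)
last = element_offset(shape, 1, 3, 2, 7)
```

Adding such an offset to the base pointer of the buffer would give the address of the selected element under the stated layout assumption.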
scales_block_paged_ptr
def scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[Self.scale_dtype], ImmutAnyOrigin]
Returns:
UnsafePointer[Scalar[Self.scale_dtype], ImmutAnyOrigin]
load_scale
def load_scale[width: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[Self.scale_dtype, width]
Returns:
SIMD[Self.scale_dtype, width]
cache_length
max_context_length
num_kv_rows
def num_kv_rows(self) -> Int
Returns the total number of virtual rows (batch * seq_len).
Returns:
Int
row_idx
def row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32
Returns the row index when viewing the memory as a matrix.
Returns:
UInt32
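One way to picture num_kv_rows and row_idx together is to flatten the batch and sequence dimensions into a single row axis. The sketch below is plain Python rather than Mojo. It assumes that every batch entry holds exactly seq_len rows and that rows are ordered batch-major, which matches the "batch * seq_len" description of num_kv_rows but is otherwise an assumption. The explicit `batch` and `seq_len` parameters are hypothetical stand-ins for values the struct reads from its buffer.

```python
def num_kv_rows(batch: int, seq_len: int) -> int:
    """Total number of virtual rows in the flattened matrix view."""
    return batch * seq_len


def row_idx(batch_idx: int, start_tok_idx: int, seq_len: int) -> int:
    """Row of (batch_idx, start_tok_idx) in the assumed batch-major view."""
    return batch_idx * seq_len + start_tok_idx


# Hypothetical sizes: 2 sequences of 1024 tokens each.
total_rows = num_kv_rows(2, 1024)
second_batch_row = row_idx(1, 5, 1024)
final_row = row_idx(1, 1023, 1024)
```

With these sizes, the last valid row index is one less than the total row count, so every (batch, token) pair maps to a distinct row.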
get_tma_row
def get_tma_row(self, encoded_index: Int32) -> Int32
Convert an encoded sparse index to a physical TMA row.
Non-paged operand: identity (no paging translation needed).
Returns:
Int32
create_tma_tile
def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[dtype_, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])
Creates a TMA tile for efficient GPU memory transfers.
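Returns:
TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]
create_scale_tma_tile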
def create_scale_tma_tile[BMN: Int](self, ctx: DeviceContext, out tma: TMATensorTile[Self.scale_dtype, Int(2), Index[Int, Int](Int(1), BMN)])
Returns:
TMATensorTile[Self.scale_dtype, Int(2), Index[Int, Int](Int(1), BMN)]
create_ragged_tma_tile
def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[dtype_, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK])
Returns:
RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK]
create_rope_tma_tile
def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])
Not supported for LayoutTensorMHAOperand.
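Returns:
TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]
create_gather4_tma_tile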
def create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), tma_dtype: DType = LayoutTensorMHAOperand[dtype_, buffer_layout, scale_dtype_, scale_buffer_layout].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])
Creates a 2D TMA gather4 descriptor for this contiguous operand.
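Returns:
TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]
create_rope_gather4_tma_tile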
def create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])
Not supported for LayoutTensorMHAOperand.
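Returns:
TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]
scales_raw_ptr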
def scales_raw_ptr(self) -> UnsafePointer[Float32, MutAnyOrigin]
Returns a dangling pointer. Contiguous operands do not support quantization.
Returns:
UnsafePointer[Float32, MutAnyOrigin]