
Mojo struct

LayoutTensorMHAOperand

struct LayoutTensorMHAOperand[origin: ImmutOrigin, scale_origin: ImmutOrigin, //, dtype_: DType, buffer_layout: TensorLayout, scale_dtype_: DType = DType.float32, scale_buffer_layout: TensorLayout = Layout[*?, *?]]

A contiguous (non-paged) tensor implementation of the MHAOperand trait, used to pass tensor arguments to multi-head attention (MHA) kernels.

Fields

  • buffer (TileTensor[Self.dtype, buffer_layout, origin])
  • scale_buffer (TileTensor[Self.scale_dtype, scale_buffer_layout, scale_origin])

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, MHAOperand, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = LayoutTensorMHAOperand[dtype_, buffer_layout, scale_dtype_, scale_buffer_layout]

dtype

comptime dtype = dtype_

layout_dim

comptime layout_dim = buffer_layout.__shape_types[SIMDSize((buffer_layout.rank - Int(1)))].static_value

layout_rank

comptime layout_rank = buffer_layout.rank

page_size

comptime page_size = 0

quantization_enabled

comptime quantization_enabled = (scale_buffer_layout.rank != Int(0))

quantization_granularity

comptime quantization_granularity = ceildiv(buffer_layout.__shape_types[(buffer_layout.rank - Int(1))].static_value, scale_buffer_layout.__shape_types[(scale_buffer_layout.rank - Int(1))].static_value if (scale_buffer_layout.rank != Int(0)) else Int(1))

scale_dim

comptime scale_dim = scale_buffer_layout.__shape_types[SIMDSize((scale_buffer_layout.rank - Int(1)))].static_value if (scale_buffer_layout.rank != Int(0)) else Int(1)
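
The quantization members above reduce to simple arithmetic on the last dimension of each layout. The following Python sketch models that arithmetic on plain shape tuples. It is not Mojo and not the library's code: the helper names `ceildiv` and `quantization_info` are hypothetical, and the sketch assumes fully static shapes, with an empty tuple standing in for a rank-0 scale layout.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for a positive divisor."""
    return -(-a // b)


def quantization_info(buffer_shape: tuple, scale_shape: tuple) -> dict:
    """Model quantization_enabled, scale_dim, and quantization_granularity.

    Assumption: shapes are tuples of static sizes, and () means "no scales".
    """
    # layout_dim: static size of the last buffer dimension.
    layout_dim = buffer_shape[-1]
    # quantization_enabled: the scale layout has nonzero rank.
    enabled = len(scale_shape) != 0
    # scale_dim: last scale dimension, or 1 when there is no scale buffer.
    scale_dim = scale_shape[-1] if enabled else 1
    return {
        "quantization_enabled": enabled,
        "scale_dim": scale_dim,
        # Number of buffer elements that share one scale value.
        "quantization_granularity": ceildiv(layout_dim, scale_dim),
    }
```

Under these assumptions, a buffer whose last dimension is 128 paired with a scale buffer whose last dimension is 4 gives a granularity of 32, meaning each scale covers 32 elements along the depth axis.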

scale_dtype

comptime scale_dtype = scale_dtype_

scale_rank

comptime scale_rank = scale_buffer_layout.rank

Methods

__init__

def __init__(buffer: TileTensor[Self.dtype, buffer_layout, origin], scale_buffer: TileTensor[Self.scale_dtype, scale_buffer_layout, scale_origin] = _null_scale_tile_tensor[LayoutTensorMHAOperand[dtype_, buffer_layout, scale_dtype_, scale_buffer_layout].scale_dtype, scale_buffer_layout]()) -> Self

get_type_name

static def get_type_name() -> String

Returns:

String

block_paged_ptr

def block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]
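
For a contiguous operand, the pointer for a given position can be pictured as the buffer's base address plus a flat element offset. The Python sketch below computes such an offset. It is only an illustration: the row-major `[batch, seq_len, num_heads, head_dim]` ordering and the helper name `flat_offset` are assumptions that this page does not state.

```python
def flat_offset(shape, batch_idx, start_tok_idx, head_idx, head_dim_idx=0):
    """Flat element offset into a row-major rank-4 buffer.

    Assumption: shape is (batch, seq_len, num_heads, head_dim).
    """
    batch, seq_len, num_heads, head_dim = shape
    # Reject out-of-range indices instead of silently wrapping.
    if not (0 <= batch_idx < batch and 0 <= start_tok_idx < seq_len):
        raise IndexError("batch or token index out of range")
    if not (0 <= head_idx < num_heads and 0 <= head_dim_idx < head_dim):
        raise IndexError("head or depth index out of range")
    # Row-major linearization: the innermost dimension varies fastest.
    row = batch_idx * seq_len + start_tok_idx
    return (row * num_heads + head_idx) * head_dim + head_dim_idx
```

Adding this offset to a base pointer would give the element address in this model. The real method may compute the address differently.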

scales_block_paged_ptr

def scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = Int(0)) -> UnsafePointer[Scalar[Self.scale_dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[Self.scale_dtype], ImmutAnyOrigin]

load_scale

def load_scale[width: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[Self.scale_dtype, width]

Returns:

SIMD[Self.scale_dtype, width]

cache_length

def cache_length(self, batch_idx: Int) -> Int

Returns:

Int

max_context_length

def max_context_length(self) -> UInt32

Returns:

UInt32

num_kv_rows

def num_kv_rows(self) -> Int

Returns the total number of virtual rows (batch * seq_len).

Returns:

Int
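
As a rough illustration of the `batch * seq_len` count described above, the Python sketch below works on a plain shape tuple. The helper name `num_kv_rows_model` is hypothetical, and the sketch assumes that the first two dimensions of the buffer are the batch size and the sequence length.

```python
def num_kv_rows_model(shape) -> int:
    """Total virtual rows of a contiguous operand.

    Assumption: shape[0] is the batch size and shape[1] is seq_len.
    """
    batch, seq_len = shape[0], shape[1]
    return batch * seq_len
```

Under that assumption, a buffer with 2 batches of 16 tokens has 32 virtual rows, regardless of the head count or head depth.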

row_idx

def row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index when the operand's memory is viewed as a 2D matrix.

Returns:

UInt32
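
One plausible reading of the matrix view is that the batch and sequence dimensions are flattened into a single row axis. The Python sketch below shows that mapping. The formula, the helper name `row_idx_model`, and the assumption that `seq_len` is the second buffer dimension are illustrative guesses rather than confirmed behavior.

```python
def row_idx_model(shape, batch_idx: int, start_tok_idx: int) -> int:
    """Row index after flattening (batch, seq_len) into one row axis.

    Assumption: shape[1] is seq_len and rows are laid out batch-major.
    """
    seq_len = shape[1]
    # Each batch contributes seq_len consecutive rows.
    return batch_idx * seq_len + start_tok_idx
```

In this model, token 3 of batch 1 in a buffer with `seq_len = 16` lands on row 19.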

get_tma_row

def get_tma_row(self, encoded_index: Int32) -> Int32

Converts an encoded sparse index to a physical TMA row.

For this non-paged operand the conversion is the identity, since no paging translation is needed.

Returns:

Int32

create_tma_tile

def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[dtype_, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Creates a TMA tile for efficient GPU memory transfers.

Returns:

TMATensorTile[Self.dtype, Int(3), _padded_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), Self.dtype, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_scale_tma_tile

def create_scale_tma_tile[BMN: Int](self, ctx: DeviceContext, out tma: TMATensorTile[Self.scale_dtype, Int(2), Index[Int, Int](Int(1), BMN)])

Returns:

TMATensorTile[Self.scale_dtype, Int(2), Index[Int, Int](Int(1), BMN)]

create_ragged_tma_tile

def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[dtype_, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK])

Returns:

RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK]

create_rope_tma_tile

def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Not supported for LayoutTensorMHAOperand.

Returns:

TMATensorTile[DType.bfloat16, Int(3), _padded_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[Int(3), DType.bfloat16, IndexList(BN, Int(1), BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_gather4_tma_tile

def create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), tma_dtype: DType = LayoutTensorMHAOperand[dtype_, buffer_layout, scale_dtype_, scale_buffer_layout].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])

Creates a 2D TMA gather4 descriptor for this contiguous operand.

Returns:

TMATensorTile[tma_dtype, Int(2), IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

create_rope_gather4_tma_tile

def create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = Int(4), l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])

Not supported for LayoutTensorMHAOperand.

Returns:

TMATensorTile[DType.bfloat16, Int(2), IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(Int(1), _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

scales_raw_ptr

def scales_raw_ptr(self) -> UnsafePointer[Float32, MutAnyOrigin]

Returns a dangling pointer. Contiguous operands do not support quantization.

Returns:

UnsafePointer[Float32, MutAnyOrigin]