IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

RaggedMHAOperand

struct RaggedMHAOperand[origin: ImmutOrigin, cache_origin: ImmutOrigin, //, dtype_: DType, layout: Layout, cache_layout: Layout, scale_dtype_: DType = DType.invalid, scale_layout: Layout = Layout()]

An `MHAOperand` implementation for ragged `LayoutTensor` arguments to multi-head attention (MHA) kernels.

Fields

  • buffer (LayoutTensor[RaggedMHAOperand[dtype_, layout, cache_layout, scale_dtype_, scale_layout].dtype, layout, origin]):
  • scale_buffer (LayoutTensor[RaggedMHAOperand[dtype_, layout, cache_layout, scale_dtype_, scale_layout].scale_dtype, scale_layout, ImmutAnyOrigin]):
  • cache_row_offsets (LayoutTensor[DType.uint32, cache_layout, cache_origin]):

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, MHAOperand, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type​

comptime device_type = RaggedMHAOperand[dtype_, layout, cache_layout, scale_dtype_, scale_layout]

dtype​

comptime dtype = dtype_

page_size​

comptime page_size = 0

quantization_enabled​

comptime quantization_enabled = False

quantization_granularity​

comptime quantization_granularity = 0

scale_dtype​

comptime scale_dtype = scale_dtype_

Methods

__init__​

def __init__(buffer: LayoutTensor[Self.dtype, layout, origin], cache_row_offsets: LayoutTensor[DType.uint32, cache_layout, cache_origin]) -> Self

def __init__(buffer: LayoutTensor[Self.dtype, layout, origin], scale_buffer: LayoutTensor[Self.scale_dtype, scale_layout, ImmutAnyOrigin], cache_row_offsets: LayoutTensor[DType.uint32, cache_layout, cache_origin]) -> Self

get_type_name​

static def get_type_name() -> String

Returns:

String

block_paged_ptr​

def block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = UInt32(0)) -> UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[Self.dtype], ImmutAnyOrigin]

scales_block_paged_ptr​

def scales_block_paged_ptr(self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> UnsafePointer[Scalar[Self.scale_dtype], ImmutAnyOrigin]

Returns:

UnsafePointer[Scalar[Self.scale_dtype], ImmutAnyOrigin]

load_scale​

def load_scale[width: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int) -> SIMD[Self.scale_dtype, width]

Returns:

SIMD[Self.scale_dtype, width]

cache_length​

def cache_length(self, batch_idx: Int) -> Int

Returns:

Int

max_context_length​

def max_context_length(self) -> UInt32

Returns:

UInt32

num_kv_rows​

def num_kv_rows(self) -> Int

Returns the total number of tokens in the ragged buffer.

Returns:

Int

row_idx​

def row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index for the given batch entry and starting token when viewing the memory as a matrix.

Returns:

UInt32

get_tma_row​

def get_tma_row(self, encoded_index: Int32) -> Int32

Convert an encoded sparse index to a physical TMA row.

Because this operand is not paged, the mapping is the identity: the encoded index is returned unchanged, with no paging translation.

Returns:

Int32

create_tma_tile​

def create_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[RaggedMHAOperand[dtype_, layout, cache_layout, scale_dtype_, scale_layout].dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: TMATensorTile[Self.dtype, 3, _padded_shape[3, Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Creates a TMA tile for efficient GPU memory transfers.

Returns:

TMATensorTile[Self.dtype, 3, _padded_shape[3, Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, Self.dtype, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_scale_tma_tile​

def create_scale_tma_tile[BMN: Int](self, ctx: DeviceContext, out tma: TMATensorTile[Self.scale_dtype, 2, Index[Int, Int](1, BMN)])

Returns:

TMATensorTile[Self.scale_dtype, 2, Index[Int, Int](1, BMN)]

create_ragged_tma_tile​

def create_ragged_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = padded_depth[RaggedMHAOperand[dtype_, layout, cache_layout, scale_dtype_, scale_layout].dtype, swizzle_mode, depth]()](self, ctx: DeviceContext, out tma: RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK])

Returns:

RaggedTMA3DTile[Self.dtype, swizzle_mode, BM=BN, BN=BK]

create_rope_tma_tile​

def create_rope_tma_tile[swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int, padded_depth: Int](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()])

Not supported for RaggedMHAOperand.

Returns:

TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(BN, 1, BK, __list_literal__=NoneType(None)), swizzle_mode]()]

create_gather4_tma_tile​

def create_gather4_tma_tile[tile_width: Int, tile_stride: Int = tile_width, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, tma_dtype: DType = RaggedMHAOperand[dtype_, layout, cache_layout, scale_dtype_, scale_layout].dtype, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])

Creates a 2D TMA gather4 descriptor for this ragged operand.

Returns:

TMATensorTile[tma_dtype, 2, IndexList(tile_height, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[tma_dtype, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

create_rope_gather4_tma_tile​

def create_rope_gather4_tma_tile[tile_width: Int, padded_depth: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, tile_height: Int = 4, l2_promotion: TensorMapL2Promotion = TensorMapL2Promotion.NONE](self, ctx: DeviceContext, out tma: TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))])

Not supported for RaggedMHAOperand.

Returns:

TMATensorTile[DType.bfloat16, 2, IndexList(tile_height, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None)), IndexList(1, _gather4_box_width[DType.bfloat16, tile_width, swizzle_mode](), __list_literal__=NoneType(None))]

scales_raw_ptr​

def scales_raw_ptr(self) -> UnsafePointer[Float32, MutAnyOrigin]

Returns a dangling pointer, because ragged operands do not support quantization (`quantization_enabled` is `False`). The returned pointer must not be dereferenced.

Returns:

UnsafePointer[Float32, MutAnyOrigin]