
Mojo struct

BlockwiseFP8TilePayload

struct BlockwiseFP8TilePayload[a_type: DType, b_type: DType, a_scales_type: DType, a_shape: IndexList[2], b_shape: IndexList[2], a_scales_shape: IndexList[2], num_pipeline_stages: Int]

Tile payload for blockwise FP8 matmul: bundles the A, B, and A-scales shared-memory (SMEM) tiles.

Unlike BlockScaledTilePayload, this payload stores only the A-scales in SMEM; the B-scales are read directly from global memory during the epilogue phase.
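For context, the arithmetic that blockwise FP8 scaling stands for can be written as a small NumPy reference. This is a sketch under assumptions this page does not state: one A-scale per row of A per K block, one B-scale per (N block, K block) pair, illustrative block sizes, and FP8 values emulated as float32. It shows only the math; where the kernel applies each scale (for example, B-scales in the epilogue) is an implementation detail it does not model.

```python
import numpy as np


def blockwise_scaled_matmul(a_q, b_q, a_scales, b_scales, block_k, block_n):
    """Reference C = A @ B.T with per-block dequantization.

    Assumed layouts (illustrative, not taken from the kernel):
      a_q:      (M, K) quantized A values (FP8 emulated as float32)
      b_q:      (N, K) quantized B values, K-major like A
      a_scales: (M, K // block_k), one scale per row of A per K block
      b_scales: (N // block_n, K // block_k), one scale per B block
    """
    m, k = a_q.shape
    n = b_q.shape[0]
    c = np.zeros((m, n), dtype=np.float32)
    for kb in range(k // block_k):
        ks = slice(kb * block_k, (kb + 1) * block_k)
        # Partial product for this K block, still in quantized units.
        partial = a_q[:, ks] @ b_q[:, ks].T
        # The A-scale varies per row; the B-scale varies per block of N columns.
        a_s = a_scales[:, kb][:, None]                      # shape (M, 1)
        b_s = np.repeat(b_scales[:, kb], block_n)[None, :]  # shape (1, N)
        c += partial * a_s * b_s
    return c


rng = np.random.default_rng(0)
M, N, K, BK, BN = 4, 8, 16, 8, 4
a_q = rng.standard_normal((M, K)).astype(np.float32)
b_q = rng.standard_normal((N, K)).astype(np.float32)
a_scales = rng.uniform(0.5, 2.0, (M, K // BK)).astype(np.float32)
b_scales = rng.uniform(0.5, 2.0, (N // BN, K // BK)).astype(np.float32)
c = blockwise_scaled_matmul(a_q, b_q, a_scales, b_scales, BK, BN)

# Cross-check: dequantize both operands first, then do a plain matmul.
a_deq = a_q * np.repeat(a_scales, BK, axis=1)
b_deq = b_q * np.repeat(np.repeat(b_scales, BN, axis=0), BK, axis=1)
assert np.allclose(c, a_deq @ b_deq.T, rtol=1e-4, atol=1e-4)
```

Applying scales per K block is what lets each block keep its own dynamic range while the inner products stay in low precision.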

Fields

  • a_tiles (BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].ATileArray): SMEM tile array for the A operand.
  • b_tiles (BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].BTileArray): SMEM tile array for the B operand.
  • a_scales_tiles (BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].AScalesTileArray): SMEM tile array for the A-scales.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TilePayload, TrivialRegisterPassable

comptime members

AScalesTile

comptime AScalesTile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].AScalesTileArray.Tile

AScalesTileArray

comptime AScalesTileArray = SMemTileArray2DRowMajor[a_scales_type, a_scales_shape[0], a_scales_shape[1], num_pipeline_stages]

ATile

comptime ATile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].ATileArray.Tile

ATileArray

comptime ATileArray = SMemTileArray2D[a_type, a_shape[0], a_shape[1], num_pipeline_stages]

BTile

comptime BTile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].BTileArray.Tile

BTileArray

comptime BTileArray = SMemTileArray2D[b_type, b_shape[0], b_shape[1], num_pipeline_stages]
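Because the payload holds three SMEM tile arrays, its shared-memory footprint can be estimated from the comptime parameters. The Python sketch below is a back-of-the-envelope calculation that assumes the last parameter of SMemTileArray2D and SMemTileArray2DRowMajor is the number of tiles in the array; the function name, dtype table, and example shapes are hypothetical, and any alignment or padding added by the SMEM allocator is ignored.

```python
# Bytes per element for the dtypes used in this sketch (hypothetical subset).
DTYPE_BYTES = {"float8_e4m3fn": 1, "bfloat16": 2, "float32": 4}


def payload_smem_bytes(a_type, b_type, a_scales_type,
                       a_shape, b_shape, a_scales_shape,
                       num_pipeline_stages):
    """Estimate SMEM bytes for the A, B, and A-scales tile arrays.

    Assumes each array stores num_pipeline_stages dense tiles and ignores
    alignment or padding.
    """
    def tile_bytes(dtype, shape):
        return shape[0] * shape[1] * DTYPE_BYTES[dtype]

    per_stage = (tile_bytes(a_type, a_shape)
                 + tile_bytes(b_type, b_shape)
                 + tile_bytes(a_scales_type, a_scales_shape))
    return per_stage * num_pipeline_stages


# Hypothetical configuration: 128x128 FP8 A and B tiles, 1x128 float32 A-scales.
total = payload_smem_bytes("float8_e4m3fn", "float8_e4m3fn", "float32",
                           (128, 128), (128, 128), (1, 128), 4)
print(total)  # 133120
```

In this configuration the A-scales add only 512 bytes per stage next to 32 KiB of operand tiles, which suggests why staging the small A-scale tiles in SMEM is cheap.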

Methods

__init__

def __init__(a_tiles: SMemTileArray2D[a_type, a_shape[0], a_shape[1], num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_shape[0], b_shape[1], num_pipeline_stages], a_scales_tiles: SMemTileArray2DRowMajor[a_scales_type, a_scales_shape[0], a_scales_shape[1], num_pipeline_stages]) -> Self

Creates a payload from the A, B, and A-scales SMEM tile arrays.

get_tile

def get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[a_scales_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]]

Gets the A, B, and A-scales tiles for the specified pipeline stage and k-group index.

Returns:

Tuple[TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[a_scales_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]]
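This page does not say how stage and k_idx map onto slots in the underlying tile arrays. One plausible scheme, given here only as an assumption, treats each stage as a group of k_group_size consecutive tiles, so the flat slot is stage * k_group_size + k_idx. The Python sketch below models that mapping; flat_tile_slot is a hypothetical helper, not part of the Mojo API.

```python
def flat_tile_slot(stage, k_idx, k_group_size, num_pipeline_stages):
    """Hypothetical mapping from (stage, k_idx) to a flat tile index.

    Assumes the tile array holds num_pipeline_stages tiles, grouped into
    num_pipeline_stages // k_group_size groups of k_group_size tiles each.
    """
    if num_pipeline_stages % k_group_size != 0:
        raise ValueError("num_pipeline_stages must be a multiple of k_group_size")
    num_groups = num_pipeline_stages // k_group_size
    if not (0 <= stage < num_groups and 0 <= k_idx < k_group_size):
        raise IndexError("stage or k_idx out of range")
    return stage * k_group_size + k_idx


# With 8 tiles grouped in pairs, every (stage, k_idx) hits a distinct slot.
slots = [flat_tile_slot(s, k, 2, 8) for s in range(4) for k in range(2)]
print(slots)  # [0, 1, 2, 3, 4, 5, 6, 7]
```

Under this assumption, a k_group_size of 1 reduces to one tile per stage, and larger groups let a single stage cover several consecutive K tiles.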

get_a_tile

def get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Self.ATile

Gets the A tile for the specified pipeline stage and k-group index.

Returns:

Self.ATile

get_b_tile

def get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Self.BTile

Gets the B tile for the specified pipeline stage and k-group index.

Returns:

Self.BTile

get_a_scales_tile

def get_a_scales_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Self.AScalesTile

Gets the A-scales tile for the specified pipeline stage and k-group index.

Returns:

Self.AScalesTile
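The per-operand getters take the same arguments as get_tile, so presumably each one returns the corresponding element of the get_tile tuple. The Python mock below illustrates that assumed contract. It is not the Mojo struct: tiles are plain string labels, k_group_size becomes an ordinary argument instead of a compile-time parameter, and the slot formula stage * k_group_size + k_idx is a hypothetical choice.

```python
class MockBlockwiseFP8TilePayload:
    """Python stand-in for the Mojo struct; tiles are plain labels.

    Assumes each array is a flat list of tiles and that (stage, k_idx)
    selects slot stage * k_group_size + k_idx (hypothetical).
    """

    def __init__(self, a_tiles, b_tiles, a_scales_tiles):
        self.a_tiles = a_tiles
        self.b_tiles = b_tiles
        self.a_scales_tiles = a_scales_tiles

    @staticmethod
    def _slot(k_group_size, stage, k_idx):
        return stage * k_group_size + k_idx

    def get_a_tile(self, k_group_size, stage, k_idx):
        return self.a_tiles[self._slot(k_group_size, stage, k_idx)]

    def get_b_tile(self, k_group_size, stage, k_idx):
        return self.b_tiles[self._slot(k_group_size, stage, k_idx)]

    def get_a_scales_tile(self, k_group_size, stage, k_idx):
        return self.a_scales_tiles[self._slot(k_group_size, stage, k_idx)]

    def get_tile(self, k_group_size, stage, k_idx):
        # Same lookup as the three getters, bundled into one tuple so the
        # A tile and its A-scales tile always come from the same slot.
        return (self.get_a_tile(k_group_size, stage, k_idx),
                self.get_b_tile(k_group_size, stage, k_idx),
                self.get_a_scales_tile(k_group_size, stage, k_idx))


n = 4
payload = MockBlockwiseFP8TilePayload(
    [f"A{i}" for i in range(n)],
    [f"B{i}" for i in range(n)],
    [f"SA{i}" for i in range(n)],
)
print(payload.get_tile(2, 1, 1))  # ('A3', 'B3', 'SA3')
```

Indexing all three arrays with one shared slot keeps each A tile paired with the scales that dequantize it.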