
Mojo struct

BlockwiseFP8TilePayload

@register_passable(trivial) struct BlockwiseFP8TilePayload[a_type: DType, b_type: DType, a_scales_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, a_scales_tile_layout: Layout, num_pipeline_stages: Int]

Tile payload for blockwise FP8 matmul, holding the A, B, and A-scales tile arrays.

Unlike BlockScaledTilePayload, this payload stores only the A-scales in shared memory (SMEM). B-scales are read directly from global memory during the epilogue phase.
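As a rough numerical illustration of blockwise scaling, the Python sketch below models the computation on plain lists. It is not Mojo and not the kernel's code. The function name `blockwise_scaled_matmul`, the scale granularity (one A scale per row per K-block, one B scale per K-block), and the point at which the scales are applied are all assumptions chosen to keep the example small.

```python
# Conceptual model only: plain Python lists stand in for FP8 tiles and scales.
# Assumed granularity (not taken from the kernel): one A scale per row per
# K-block, and one B scale per K-block.

def blockwise_scaled_matmul(a, b, a_scales, b_scales, block_k):
    """Return C where each K-block's partial product is scaled before accumulation.

    a:        M x K matrix (list of rows)
    b:        K x N matrix (list of rows)
    a_scales: M x (K // block_k) per-row, per-K-block scales
    b_scales: list of K // block_k per-K-block scales
    """
    m, k, n = len(a), len(b), len(b[0])
    assert k % block_k == 0, "K must be a multiple of block_k"
    c = [[0.0] * n for _ in range(m)]
    for kb in range(k // block_k):
        lo, hi = kb * block_k, (kb + 1) * block_k
        for i in range(m):
            for j in range(n):
                # Unscaled partial dot product over this K-block.
                partial = sum(a[i][p] * b[p][j] for p in range(lo, hi))
                # Apply both scales to the partial result, then accumulate.
                c[i][j] += partial * a_scales[i][kb] * b_scales[kb]
    return c


result = blockwise_scaled_matmul(
    a=[[1, 2, 3, 4]],
    b=[[1], [1], [1], [1]],
    a_scales=[[0.5, 2.0]],
    b_scales=[2.0, 1.0],
    block_k=2,
)
```

In this model the A scales and B scales enter the same multiply. The payload documented here separates where they live: A-scales travel through SMEM with the A and B tiles, while B-scales are fetched from global memory in the epilogue.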

Fields

  • a_tiles (BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].ATileArray): Shared-memory tile array for the A operand.
  • b_tiles (BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].BTileArray): Shared-memory tile array for the B operand.
  • a_scales_tiles (BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].AScalesTileArray): Shared-memory tile array for the A-scales.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, TilePayload

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

AScalesTile

comptime AScalesTile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].AScalesTileArray.Tile

AScalesTileArray

comptime AScalesTileArray = SMemTileArray[a_scales_type, a_scales_tile_layout, num_pipeline_stages, 128]

ATile

comptime ATile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].ATileArray.Tile

ATileArray

comptime ATileArray = SMemTileArray[a_type, a_tile_layout, num_pipeline_stages, 128]

BTile

comptime BTile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].BTileArray.Tile

BTileArray

comptime BTileArray = SMemTileArray[b_type, b_tile_layout, num_pipeline_stages, 128]

Methods

__init__

__init__(a_tiles: SMemTileArray[a_type, a_tile_layout, num_pipeline_stages, 128], b_tiles: SMemTileArray[b_type, b_tile_layout, num_pipeline_stages, 128], a_scales_tiles: SMemTileArray[a_scales_type, a_scales_tile_layout, num_pipeline_stages, 128]) -> Self

get_tile

get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].ATile, BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].BTile, BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].AScalesTile]

Get A, B, A-scales tiles at the specified stage and k-group index.

Returns:

Tuple: The (ATile, BTile, AScalesTile) triple for the given stage and k-group index.
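The Python sketch below models how a (stage, k-group index) pair could select one slot from three parallel tile arrays. It is a conceptual model, not the Mojo implementation. The class `TilePayloadModel` is hypothetical, and the mapping `stage * k_group_size + k_idx` is an assumption about how the two indices combine. This page does not state the actual formula.

```python
class TilePayloadModel:
    """Hypothetical stand-in for a tile payload: three parallel lists of tiles."""

    def __init__(self, a_tiles, b_tiles, a_scales_tiles):
        # All three arrays are assumed to have one entry per pipeline slot.
        assert len(a_tiles) == len(b_tiles) == len(a_scales_tiles)
        self.a_tiles = a_tiles
        self.b_tiles = b_tiles
        self.a_scales_tiles = a_scales_tiles

    def slot(self, k_group_size, stage, k_idx):
        # ASSUMED mapping: each stage owns k_group_size consecutive slots.
        assert 0 <= k_idx < k_group_size
        idx = stage * k_group_size + k_idx
        assert 0 <= idx < len(self.a_tiles), "slot out of range"
        return idx

    def get_tile(self, k_group_size, stage, k_idx):
        # Return the A, B, and A-scales tiles that share one slot.
        idx = self.slot(k_group_size, stage, k_idx)
        return (self.a_tiles[idx], self.b_tiles[idx], self.a_scales_tiles[idx])


# Four slots, for example 2 stages x k_group_size of 2 (labels stand in for tiles).
payload = TilePayloadModel(
    a_tiles=["A0", "A1", "A2", "A3"],
    b_tiles=["B0", "B1", "B2", "B3"],
    a_scales_tiles=["S0", "S1", "S2", "S3"],
)
tiles = payload.get_tile(k_group_size=2, stage=1, k_idx=1)
```

Under this assumed mapping, all three tiles returned by one call come from the same slot, which is why the documented `get_tile` can return them together as a single tuple.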

get_a_tile

get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].ATile

Get A tile at the specified stage and k-group index.

Returns:

ATile: The A tile at the given stage and k-group index.

get_b_tile

get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].BTile

Get B tile at the specified stage and k-group index.

Returns:

BTile: The B tile at the given stage and k-group index.

get_a_scales_tile

get_a_scales_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_tile_layout, b_tile_layout, a_scales_tile_layout, num_pipeline_stages].AScalesTile

Get A-scales tile at the specified stage and k-group index.

Returns:

AScalesTile: The A-scales tile at the given stage and k-group index.
