Mojo struct
BlockwiseFP8TilePayload
struct BlockwiseFP8TilePayload[a_type: DType, b_type: DType, a_scales_type: DType, a_shape: IndexList[2], b_shape: IndexList[2], a_scales_shape: IndexList[2], num_pipeline_stages: Int]
Tile payload for blockwise FP8 matmul (A, B, A-scales tiles).
Unlike BlockScaledTilePayload, this payload keeps only the A-scales in shared memory (SMEM), alongside the A and B tiles. B-scales are read directly from global memory during the epilogue phase.
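The following minimal Python model shows what the stored scales mean in a blockwise-scaled matmul. It is a sketch under stated assumptions, not the kernel's code. It assumes A is M x K and B is N x K. Each A-scale covers one row of A over one K-block, each B-scale covers a block of block_n rows of B over one K-block, and the inputs are already-decoded FP8 values. The function name and its parameters are hypothetical. The model also ignores where each scale lives (SMEM for A, global memory for B) and when the kernel applies it.

```python
# Hypothetical reference model of blockwise-scaled matmul (not the Mojo kernel).
# ASSUMPTIONS: A is M x K, B is N x K (both K-major); a_scale[m][kb] scales
# row m of A over K-block kb; b_scale[n // block_n][kb] scales a block of
# block_n rows of B over K-block kb; inputs are already-decoded FP8 values.


def blockwise_scaled_matmul(A, B, a_scale, b_scale, block_k, block_n):
    M, K = len(A), len(A[0])
    N = len(B)
    if K % block_k != 0:
        raise ValueError("K must be a multiple of block_k")
    num_k_blocks = K // block_k
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for kb in range(num_k_blocks):
                # Unscaled partial dot product over one K-block.
                partial = sum(
                    A[m][k] * B[n][k]
                    for k in range(kb * block_k, (kb + 1) * block_k)
                )
                # Per-row A-scale times per-(N-block, K-block) B-scale.
                acc += partial * a_scale[m][kb] * b_scale[n // block_n][kb]
            C[m][n] = acc
    return C


C = blockwise_scaled_matmul(
    A=[[1, 2, 3, 4]],
    B=[[1, 1, 1, 1], [2, 0, 0, 2]],
    a_scale=[[0.5, 2.0]],
    b_scale=[[1.0, 3.0]],
    block_k=2,
    block_n=2,
)
print(C)  # [[43.5, 49.0]]
```

Each K-block's raw dot product is weighted by its own pair of scales before it is added to the accumulator. For that reason the scales cannot simply be applied once to the final sum.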
Fields
- a_tiles (BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].ATileArray): shared-memory tile array for A, sized by num_pipeline_stages.
- b_tiles (BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].BTileArray): shared-memory tile array for B, sized by num_pipeline_stages.
- a_scales_tiles (BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].AScalesTileArray): row-major shared-memory tile array for the A-scales, sized by num_pipeline_stages.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TilePayload,
TrivialRegisterPassable
comptime members
AScalesTile
comptime AScalesTile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].AScalesTileArray.Tile
AScalesTileArray
comptime AScalesTileArray = SMemTileArray2DRowMajor[a_scales_type, a_scales_shape[0], a_scales_shape[1], num_pipeline_stages]
ATile
comptime ATile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].ATileArray.Tile
ATileArray
comptime ATileArray = SMemTileArray2D[a_type, a_shape[0], a_shape[1], num_pipeline_stages]
BTile
comptime BTile = BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].BTileArray.Tile
BTileArray
comptime BTileArray = SMemTileArray2D[b_type, b_shape[0], b_shape[1], num_pipeline_stages]
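Every array is sized by its comptime shape and num_pipeline_stages, so the shared-memory cost of the payload can be estimated with simple arithmetic. The Python sketch below assumes dense tiles with no padding, swizzle, or alignment overhead. The dtype-size table, the function names, and the example shapes are illustrative assumptions, not values taken from the library.

```python
# Hypothetical SMEM budget model for the payload's three tile arrays.
# ASSUMPTION: each array holds num_stages dense tiles of shape[0] x shape[1]
# elements, with no padding, swizzle, or alignment overhead.

DTYPE_BYTES = {"float8_e4m3fn": 1, "bfloat16": 2, "float32": 4}


def tile_array_bytes(dtype, shape, num_stages):
    rows, cols = shape
    return rows * cols * DTYPE_BYTES[dtype] * num_stages


def payload_smem_bytes(a_type, b_type, a_scales_type,
                       a_shape, b_shape, a_scales_shape, num_stages):
    # Mirrors ATileArray, BTileArray, and AScalesTileArray, which all share
    # the same num_pipeline_stages parameter.
    return (
        tile_array_bytes(a_type, a_shape, num_stages)
        + tile_array_bytes(b_type, b_shape, num_stages)
        + tile_array_bytes(a_scales_type, a_scales_shape, num_stages)
    )


# Illustrative shapes only.
total = payload_smem_bytes(
    "float8_e4m3fn", "float8_e4m3fn", "float32",
    a_shape=(128, 128), b_shape=(128, 128), a_scales_shape=(1, 128),
    num_stages=4,
)
print(total)  # 133120
```

Take hypothetical 128 x 128 FP8 A and B tiles plus 1 x 128 float32 A-scale tiles over 4 stages. They come to 4 x (16384 + 16384 + 512) = 133120 bytes, so the A-scales add little next to the operand tiles.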
Methods
__init__
__init__(a_tiles: SMemTileArray2D[a_type, a_shape[0], a_shape[1], num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_shape[0], b_shape[1], num_pipeline_stages], a_scales_tiles: SMemTileArray2DRowMajor[a_scales_type, a_scales_shape[0], a_scales_shape[1], num_pipeline_stages]) -> Self
get_tile
get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[a_scales_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]]
Get A, B, A-scales tiles at the specified stage and k-group index.
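Returns:
A tuple of the A tile, B tile, and A-scales tile, each a TileTensor in shared memory (AddressSpace.SHARED).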
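All of the accessors share one calling pattern: a compile-time k_group_size plus a runtime stage and k_idx. The Python model below treats get_tile as the three single-tile getters evaluated at the same coordinates. It locates tiles with a stage-major flat index, stage * k_group_size + k_idx. That index is an assumption made for illustration, and the Mojo implementation may locate tiles differently. TilePayloadModel is a hypothetical name, and k_group_size is an ordinary argument here rather than a comptime parameter.

```python
# Hypothetical Python model of get_tile and the single-tile getters.
# ASSUMPTION: each array holds num_stages * k_group_size tiles in
# stage-major order, so (stage, k_idx) maps to stage * k_group_size + k_idx.
# Tiles are represented by string labels for illustration.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class TilePayloadModel:
    a_tiles: List[str]
    b_tiles: List[str]
    a_scales_tiles: List[str]

    @staticmethod
    def _index(k_group_size: int, stage: int, k_idx: int) -> int:
        if stage < 0 or not 0 <= k_idx < k_group_size:
            raise IndexError("stage must be >= 0 and k_idx in [0, k_group_size)")
        return stage * k_group_size + k_idx

    def get_a_tile(self, k_group_size: int, stage: int, k_idx: int) -> str:
        return self.a_tiles[self._index(k_group_size, stage, k_idx)]

    def get_b_tile(self, k_group_size: int, stage: int, k_idx: int) -> str:
        return self.b_tiles[self._index(k_group_size, stage, k_idx)]

    def get_a_scales_tile(self, k_group_size: int, stage: int, k_idx: int) -> str:
        return self.a_scales_tiles[self._index(k_group_size, stage, k_idx)]

    def get_tile(self, k_group_size: int, stage: int, k_idx: int) -> Tuple[str, str, str]:
        # One call returns matching A, B, and A-scales tiles.
        return (
            self.get_a_tile(k_group_size, stage, k_idx),
            self.get_b_tile(k_group_size, stage, k_idx),
            self.get_a_scales_tile(k_group_size, stage, k_idx),
        )


# Two stages, k_group_size = 2: four tiles per array.
payload = TilePayloadModel(
    a_tiles=["A0", "A1", "A2", "A3"],
    b_tiles=["B0", "B1", "B2", "B3"],
    a_scales_tiles=["S0", "S1", "S2", "S3"],
)
print(payload.get_tile(2, 1, 0))  # ('A2', 'B2', 'S2')
```

With two stages and k_group_size = 2, get_tile(2, 1, 0) selects the third entry of each array. The A, B, and A-scales tiles therefore always come back in step.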
get_a_tile
get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].ATile
Get A tile at the specified stage and k-group index.
Returns:
BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].ATile
get_b_tile
get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].BTile
Get B tile at the specified stage and k-group index.
Returns:
BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].BTile
get_a_scales_tile
get_a_scales_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].AScalesTile
Get A-scales tile at the specified stage and k-group index.
Returns:
BlockwiseFP8TilePayload[a_type, b_type, a_scales_type, a_shape, b_shape, a_scales_shape, num_pipeline_stages].AScalesTile
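A consumer loop has to turn a running K-tile counter into the stage and k_idx arguments that these methods take. The Python sketch below shows one plausible round-robin mapping. It assumes K-tiles are delivered in groups of k_group_size per pipeline slot and that the num_pipeline_stages slots are reused in order. stage_schedule is a hypothetical helper, not part of the documented API.

```python
# Hypothetical schedule mapping a running K-tile counter to (stage, k_idx).
# ASSUMPTIONS: K-tiles are delivered in groups of k_group_size per pipeline
# slot, and the num_stages slots are reused round-robin.


def stage_schedule(num_k_tiles, k_group_size, num_stages):
    schedule = []
    for i in range(num_k_tiles):
        group = i // k_group_size   # which pipeline fill this K-tile belongs to
        stage = group % num_stages  # ring-buffer slot that holds the group
        k_idx = i % k_group_size    # position of the K-tile within its group
        schedule.append((stage, k_idx))
    return schedule


print(stage_schedule(6, k_group_size=2, num_stages=2))
# [(0, 0), (0, 1), (1, 0), (1, 1), (0, 0), (0, 1)]
```

With six K-tiles, k_group_size = 2, and two stages, the schedule wraps back to stage 0 after the fourth tile. Under this assumption a slot is overwritten only once the consumer has finished both tiles of its previous group.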