Skip to main content

Mojo struct

BlockScaledTilePayload

@register_passable(trivial) struct BlockScaledTilePayload[a_type: DType, b_type: DType, sfa_type: DType, sfb_type: DType, a_dim0: Int, a_dim1: Int, b_dim0: Int, b_dim1: Int, sfa_dim0: Int, sfa_dim1: Int, sfb_dim0: Int, sfb_dim1: Int, num_pipeline_stages: Int]

Tile payload for block-scaled matmul: bundles the A and B operand tiles together with their scale-factor tiles (SFA for A, SFB for B).

Fields

  • a_tiles (BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].ATileArray): Array of A tiles (`a_type`, `a_dim0` x `a_dim1`), one per pipeline stage.
  • b_tiles (BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].BTileArray): Array of B tiles (`b_type`, `b_dim0` x `b_dim1`), one per pipeline stage.
  • sfa_tiles (BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFATileArray): Array of A scale-factor tiles (`sfa_type`, laid out with `sfa_layout`), one per pipeline stage.
  • sfb_tiles (BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFBTileArray): Array of B scale-factor tiles (`sfb_type`, laid out with `sfb_layout`), one per pipeline stage.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TilePayload, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

ATile

comptime ATile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].ATileArray.Tile

ATileArray

comptime ATileArray = SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages]

BTile

comptime BTile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].BTileArray.Tile

BTileArray

comptime BTileArray = SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages]

sfa_layout

comptime sfa_layout = Layout(Coord(VariadicPack(Coord(VariadicPack(Idx[32](), Idx[(sfa_dim0 // 32)]())), Coord(VariadicPack(Coord(VariadicPack(Idx[4](), Idx[4]())), Idx[(sfa_dim1 // 16)]())))), Coord(VariadicPack(Coord(VariadicPack(Idx[16](), Idx[(sfa_dim1 * 32)]())), Coord(VariadicPack(Coord(VariadicPack(Idx[1](), Idx[4]())), Idx[512]())))))

SFATile

comptime SFATile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFATileArray.Tile

SFATileArray

comptime SFATileArray = SMemTileArrayWithLayout[sfa_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].sfa_layout, num_pipeline_stages]

sfb_layout

comptime sfb_layout = Layout(Coord(VariadicPack(Coord(VariadicPack(Idx[32](), Idx[(sfb_dim0 // 32)]())), Coord(VariadicPack(Coord(VariadicPack(Idx[4](), Idx[4]())), Idx[(sfb_dim1 // 16)]())))), Coord(VariadicPack(Coord(VariadicPack(Idx[16](), Idx[(sfb_dim1 * 32)]())), Coord(VariadicPack(Coord(VariadicPack(Idx[1](), Idx[4]())), Idx[512]())))))

SFBTile

comptime SFBTile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFBTileArray.Tile

SFBTileArray

comptime SFBTileArray = SMemTileArrayWithLayout[sfb_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].sfb_layout, num_pipeline_stages]

Methods

__init__

__init__(a_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, num_pipeline_stages], sfa_tiles: SMemTileArrayWithLayout[sfa_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].sfa_layout, num_pipeline_stages], sfb_tiles: SMemTileArrayWithLayout[sfb_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].sfb_layout, num_pipeline_stages]) -> Self

get_tile

get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].ATile, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].BTile, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFATile, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFBTile]

Get A, B, SFA, SFB tiles at the specified stage and k-group index.

Returns:

A `Tuple` of `(ATile, BTile, SFATile, SFBTile)` for the given stage and k-group index.

get_a_tile

get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].ATile

Get A tile at the specified stage and k-group index.

Returns:

The `ATile` for the given stage and k-group index.

get_b_tile

get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].BTile

Get B tile at the specified stage and k-group index.

Returns:

The `BTile` for the given stage and k-group index.

get_sfa_tile

get_sfa_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFATile

Get SFA tile at the specified stage and k-group index.

Returns:

The `SFATile` for the given stage and k-group index.

get_sfb_tile

get_sfb_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_dim0, a_dim1, b_dim0, b_dim1, sfa_dim0, sfa_dim1, sfb_dim0, sfb_dim1, num_pipeline_stages].SFBTile

Get SFB tile at the specified stage and k-group index.

Returns:

The `SFBTile` for the given stage and k-group index.

Was this page helpful?