Mojo struct
BlockScaledTilePayload
struct BlockScaledTilePayload[a_type: DType, b_type: DType, sfa_type: DType, sfb_type: DType, a_shape: IndexList[2], b_shape: IndexList[2], sfa_shape: IndexList[2], sfb_shape: IndexList[2], num_pipeline_stages: Int]
Tile payload for block-scaled matmul: holds the A and B operand tiles together with their scale-factor tiles (SFA, SFB).
Fields
- a_tiles (BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].ATileArray)
- b_tiles (BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].BTileArray)
- sfa_tiles (BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].SFATileArray)
- sfb_tiles (BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].SFBTileArray)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TilePayload,
TrivialRegisterPassable
comptime members
ATile
comptime ATile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].ATileArray.Tile
ATileArray
comptime ATileArray = SMemTileArray2D[a_type, a_shape[0], a_shape[1], num_pipeline_stages]
BTile
comptime BTile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].BTileArray.Tile
BTileArray
comptime BTileArray = SMemTileArray2D[b_type, b_shape[0], b_shape[1], num_pipeline_stages]
sfa_layout
comptime sfa_layout = Layout(Coord(Coord(Idx[32](), Idx[(sfa_shape[0] // 32)]()), Coord(Coord(Idx[4](), Idx[4]()), Idx[(sfa_shape[1] // 16)]())), Coord(Coord(Idx[16](), Idx[(sfa_shape[1] * 32)]()), Coord(Coord(Idx[1](), Idx[4]()), Idx[512]())))
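The shape and stride pairs in this layout can be read as a plain offset formula. The Python sketch below is only a model of that arithmetic, not the Mojo `Layout` API: it assumes CuTe-style semantics, where a nested mode splits a logical index with its leftmost sub-mode varying fastest, and that `sfa_shape[0]` is a multiple of 32 and `sfa_shape[1]` a multiple of 16. The helper name `sf_offset` and the 64 x 32 example shape are hypothetical.

```python
def sf_offset(row: int, col: int, sf_cols: int) -> int:
    """Model the offset of logical element (row, col) under the sfa_layout strides.

    ASSUMPTION: leftmost sub-mode varies fastest (CuTe-style decomposition).
    sf_cols plays the role of sfa_shape[1].
    """
    # Row mode: shape (32, rows // 32), stride (16, sf_cols * 32)
    r0, r1 = row % 32, row // 32
    # Column mode: shape ((4, 4), cols // 16), stride ((1, 4), 512)
    c00 = col % 4
    c01 = (col // 4) % 4
    c1 = col // 16
    return r0 * 16 + r1 * (sf_cols * 32) + c00 * 1 + c01 * 4 + c1 * 512


# Hypothetical scale-factor tile shape, chosen only to exercise every mode.
rows, cols = 64, 32
offsets = [sf_offset(r, c, cols) for r in range(rows) for c in range(cols)]

# Every offset in [0, rows * cols) is hit exactly once, so the mapping is compact.
assert sorted(offsets) == list(range(rows * cols))
```

Under these assumptions, each 32 x 16 block of scale factors occupies 512 consecutive elements (row stride 16 inside the block), blocks along the second dimension are 512 elements apart, and blocks along the first dimension are `sfa_shape[1] * 32` elements apart.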
SFATile
comptime SFATile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].SFATileArray.Tile
SFATileArray
comptime SFATileArray = SMemTileArrayWithLayout[sfa_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].sfa_layout, num_pipeline_stages]
sfb_layout
comptime sfb_layout = Layout(Coord(Coord(Idx[32](), Idx[(sfb_shape[0] // 32)]()), Coord(Coord(Idx[4](), Idx[4]()), Idx[(sfb_shape[1] // 16)]())), Coord(Coord(Idx[16](), Idx[(sfb_shape[1] * 32)]()), Coord(Coord(Idx[1](), Idx[4]()), Idx[512]())))
SFBTile
comptime SFBTile = BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].SFBTileArray.Tile
SFBTileArray
comptime SFBTileArray = SMemTileArrayWithLayout[sfb_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].sfb_layout, num_pipeline_stages]
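Because all four tile arrays are parameterized by a shape and by `num_pipeline_stages`, a rough shared-memory budget can be estimated from the struct parameters alone. The Python sketch below is a back-of-envelope model, assuming each array stores `num_pipeline_stages` densely packed tiles with no padding or alignment overhead; the page does not state this. The helper names, element widths, and shapes are hypothetical.

```python
def tile_array_bytes(shape, elem_bits, stages):
    """Bytes for `stages` dense tiles of `shape` with `elem_bits`-wide elements."""
    rows, cols = shape
    return rows * cols * elem_bits * stages // 8


def payload_bytes(a_shape, b_shape, sfa_shape, sfb_shape,
                  a_bits, b_bits, sfa_bits, sfb_bits, stages):
    """Sum the four tile arrays (ASSUMPTION: no padding between them)."""
    return (tile_array_bytes(a_shape, a_bits, stages)
            + tile_array_bytes(b_shape, b_bits, stages)
            + tile_array_bytes(sfa_shape, sfa_bits, stages)
            + tile_array_bytes(sfb_shape, sfb_bits, stages))


# Hypothetical configuration: 8-bit operands and scale factors, 4 stages.
total = payload_bytes((128, 64), (128, 64), (128, 16), (128, 16),
                      8, 8, 8, 8, stages=4)
print(total)
```

Element widths are passed in bits so that sub-byte operand types can be modeled with the same helper.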
Methods
__init__
__init__(a_tiles: SMemTileArray2D[a_type, a_shape[0], a_shape[1], num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_shape[0], b_shape[1], num_pipeline_stages], sfa_tiles: SMemTileArrayWithLayout[sfa_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].sfa_layout, num_pipeline_stages], sfb_tiles: SMemTileArrayWithLayout[sfb_type, BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].sfb_layout, num_pipeline_stages]) -> Self
get_tile
get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[sfa_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[sfb_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]]
Get A, B, SFA, SFB tiles at the specified stage and k-group index.
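Returns:
Tuple[TileTensor[a_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[b_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[sfa_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED], TileTensor[sfb_type, Layout[*?, *?], MutAnyOrigin, address_space=AddressSpace.SHARED]]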
get_a_tile
get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].ATile
Get A tile at the specified stage and k-group index.
Returns:
BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].ATile
get_b_tile
get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].BTile
Get B tile at the specified stage and k-group index.
Returns:
BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].BTile
get_sfa_tile
get_sfa_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].SFATile
Get SFA tile at the specified stage and k-group index.
Returns:
BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].SFATile
get_sfb_tile
get_sfb_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].SFBTile
Get SFB tile at the specified stage and k-group index.
Returns:
BlockScaledTilePayload[a_type, b_type, sfa_type, sfb_type, a_shape, b_shape, sfa_shape, sfb_shape, num_pipeline_stages].SFBTile
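All five accessors take the same `k_group_size`, `stage`, and `k_idx` arguments, which suggests that they select a single slot from each tile array. The Python model below illustrates one plausible mapping, `stage * k_group_size + k_idx`. That formula is an assumption, since this page does not document how the index is computed. The class `PayloadModel` and the string placeholders standing in for tiles are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class PayloadModel:
    """Toy stand-in for the payload: each field is a flat list of tiles."""
    a_tiles: list
    b_tiles: list
    sfa_tiles: list
    sfb_tiles: list

    @staticmethod
    def slot(k_group_size: int, stage: int, k_idx: int) -> int:
        # ASSUMPTION: each stage owns k_group_size consecutive slots.
        return stage * k_group_size + k_idx

    def get_tile(self, k_group_size: int, stage: int, k_idx: int):
        # Return the A, B, SFA, and SFB tiles that share one slot.
        i = self.slot(k_group_size, stage, k_idx)
        return (self.a_tiles[i], self.b_tiles[i],
                self.sfa_tiles[i], self.sfb_tiles[i])


# Four slots in total, modeled as 2 stages with 2 k-groups each.
slots = 4
payload = PayloadModel(
    a_tiles=[f"A{i}" for i in range(slots)],
    b_tiles=[f"B{i}" for i in range(slots)],
    sfa_tiles=[f"SFA{i}" for i in range(slots)],
    sfb_tiles=[f"SFB{i}" for i in range(slots)],
)
picked = payload.get_tile(k_group_size=2, stage=1, k_idx=0)
print(picked)
```

In this model the per-operand accessors (`get_a_tile` and so on) would return the corresponding element of the same tuple.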