Mojo struct
StandardTilePayload
struct StandardTilePayload[a_type: DType, b_type: DType, a_shape: IndexList[2], b_shape: IndexList[2], num_pipeline_stages: Int]
Tile payload for standard matmul (A and B tiles).
Uses explicit dimensions for the tile arrays. The tiles are stored as TileTensor with row_major layout, and the TileTensors are passed directly to TMA/MMA.
Fields
- a_tiles (StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].ATileArray):
- b_tiles (StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].BTileArray):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TilePayload,
TrivialRegisterPassable
comptime members
ATile
comptime ATile = StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].ATileArray.Tile
ATileArray
comptime ATileArray = SMemTileArray2D[a_type, a_shape[0], a_shape[1], num_pipeline_stages]
BTile
comptime BTile = StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].BTileArray.Tile
BTileArray
comptime BTileArray = SMemTileArray2D[b_type, b_shape[0], b_shape[1], num_pipeline_stages]
Methods
__init__
__init__(a_tiles: SMemTileArray2D[a_type, a_shape[0], a_shape[1], num_pipeline_stages], b_tiles: SMemTileArray2D[b_type, b_shape[0], b_shape[1], num_pipeline_stages]) -> Self
get_tile
get_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> Tuple[StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].ATile, StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].BTile]
Get A and B tiles at the specified stage and k-group index.
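Returns:
Tuple[StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].ATile, StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].BTile]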
get_a_tile
get_a_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].ATile
Get A tile at the specified stage and k-group index.
Returns:
StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].ATile
get_b_tile
get_b_tile[k_group_size: Int](self, stage: UInt32, k_idx: Int) -> StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].BTile
Get B tile at the specified stage and k-group index.
Returns:
StandardTilePayload[a_type, b_type, a_shape, b_shape, num_pipeline_stages].BTile
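The three getters above share the same (stage, k_idx) addressing. As a rough illustration only, the following Python model shows one way such a lookup could work. It is not the Mojo API: the class name TilePayloadModel, the helper _slot, and in particular the flat-index formula stage * k_group_size + k_idx are assumptions made for this sketch and are not stated on this page.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TilePayloadModel:
    """Hypothetical host-side model of a two-array tile payload (not the Mojo type)."""

    a_tiles: tuple  # stands in for ATileArray, one entry per pipeline slot
    b_tiles: tuple  # stands in for BTileArray, same length as a_tiles

    def _slot(self, k_group_size: int, stage: int, k_idx: int) -> int:
        # ASSUMPTION: each stage owns k_group_size consecutive slots.
        if not 0 <= k_idx < k_group_size:
            raise IndexError("k_idx must be in [0, k_group_size)")
        slot = stage * k_group_size + k_idx
        if not 0 <= slot < len(self.a_tiles):
            raise IndexError("slot is outside the tile array")
        return slot

    def get_a_tile(self, k_group_size: int, stage: int, k_idx: int):
        return self.a_tiles[self._slot(k_group_size, stage, k_idx)]

    def get_b_tile(self, k_group_size: int, stage: int, k_idx: int):
        return self.b_tiles[self._slot(k_group_size, stage, k_idx)]

    def get_tile(self, k_group_size: int, stage: int, k_idx: int):
        # Both arrays are indexed with the same slot, so A and B stay paired.
        slot = self._slot(k_group_size, stage, k_idx)
        return self.a_tiles[slot], self.b_tiles[slot]


# Four slots, labelled so the selected slot is visible in the result.
payload = TilePayloadModel(
    a_tiles=tuple(f"A{i}" for i in range(4)),
    b_tiles=tuple(f"B{i}" for i in range(4)),
)
pair = payload.get_tile(k_group_size=2, stage=1, k_idx=0)
```

Under the assumed formula, stage 1 with k_group_size 2 and k_idx 0 selects slot 2 in both arrays, which is the pairing that get_tile is documented to return as a tuple.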