Mojo struct
BlockScaledTilePipeline
@register_passable(trivial)
struct BlockScaledTilePipeline[a_type: DType, b_type: DType, sfa_type: DType, sfb_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, sfa_tile_layout: Layout, sfb_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]
Staged shared-memory tile storage for the A and B operand tiles and their scale-factor tiles (SFA, SFB), with producer-consumer synchronization.
Fields
- pipeline (BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].Pipeline): producer-consumer barrier pipeline over num_group_stages stages.
- a_tiles (BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray): staged shared-memory tiles for operand A.
- b_tiles (BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray): staged shared-memory tiles for operand B.
- sfa_tiles (BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFATileArray): staged shared-memory scale-factor tiles for A.
- sfb_tiles (BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFBTileArray): staged shared-memory scale-factor tiles for B.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
ATile
comptime ATile = BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray.Tile
ATileArray
comptime ATileArray = SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128]
BarrierArray
comptime BarrierArray = SMemArrayType[SharedMemBarrier, (num_group_stages * 2)]
BTile
comptime BTile = BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray.Tile
BTileArray
comptime BTileArray = SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]
Pipeline
comptime Pipeline = ProducerConsumerPipeline[num_group_stages]
SFATile
comptime SFATile = BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFATileArray.Tile
SFATileArray
comptime SFATileArray = SMemTileArrayType[sfa_type, sfa_tile_layout, num_pipeline_stages, 128]
SFBTile
comptime SFBTile = BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFBTileArray.Tile
SFBTileArray
comptime SFBTileArray = SMemTileArrayType[sfb_type, sfb_tile_layout, num_pipeline_stages, 128]
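The comptime members size every tile array by num_pipeline_stages, but size the barrier array (and ProducerConsumerPipeline) by num_group_stages. A minimal sketch of one plausible reading follows, assuming num_pipeline_stages == num_group_stages * k_group_size so that each group stage, guarded by one barrier pair, owns k_group_size consecutive tile slots. That relationship is an assumption not stated on this page, and group_stage_of and slot_of are hypothetical helpers, not library functions.

```mojo
from testing import assert_equal


# Hypothetical helpers, not part of the library. ASSUMPTION:
# num_pipeline_stages == num_group_stages * k_group_size, so each group
# stage (one barrier pair) owns k_group_size consecutive tile slots in
# ATileArray, BTileArray, SFATileArray and SFBTileArray.
fn group_stage_of(slot: Int, k_group_size: Int) -> Int:
    return slot // k_group_size


fn slot_of(group_stage: Int, k: Int, k_group_size: Int) -> Int:
    # k is the tile's position inside its group: 0 <= k < k_group_size.
    return group_stage * k_group_size + k


fn main() raises:
    var num_group_stages = 3
    var k_group_size = 2
    var num_pipeline_stages = num_group_stages * k_group_size  # 6 tile slots
    # BarrierArray holds num_group_stages * 2 barriers: one pair per group.
    var num_barriers = num_group_stages * 2
    assert_equal(num_barriers, 6)

    for slot in range(num_pipeline_stages):
        var g = group_stage_of(slot, k_group_size)
        var k = slot % k_group_size
        # Round-trip check: (group stage, k) identifies the slot uniquely.
        assert_equal(slot_of(g, k, k_group_size), slot)
        print("tile slot", slot, "-> group stage", g, ", k", k)
```

Grouping tiles this way would let one barrier wait cover several K-slices of A, B, SFA and SFB at once, which is one reason a design might keep two separate stage counts.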
Methods
__init__
__init__(barriers: SMemArrayType[SharedMemBarrier, (num_group_stages * 2)], a_tiles: SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128], b_tiles: SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128], sfa_tiles: SMemTileArrayType[sfa_type, sfa_tile_layout, num_pipeline_stages, 128], sfb_tiles: SMemTileArrayType[sfb_type, sfb_tile_layout, num_pipeline_stages, 128]) -> Self
Initialize from a typed barrier array and the A, B, SFA, and SFB tile arrays.
init_barriers
static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)
Initialize the pipeline barriers. Called once, by the single thread selected with elect_one.
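init_barriers takes separate producer and consumer arrival counts for the num_group_stages * 2 barriers. The sketch below models one common layout for such an array, assuming (not confirmed by this page) that the first num_group_stages barriers are "full" barriers the producer arrives on and the remaining num_group_stages are "empty" barriers the consumer arrives on. full_barrier, empty_barrier and arrival_count are hypothetical helpers, and the counts used are illustrative values only.

```mojo
from testing import assert_equal


# Hypothetical model of a 2 * num_group_stages barrier array. ASSUMPTION:
# barriers [0, num_group_stages) are "full" barriers (producer -> consumer)
# and [num_group_stages, 2 * num_group_stages) are "empty" barriers
# (consumer -> producer).
fn full_barrier(stage: Int) -> Int:
    return stage


fn empty_barrier(stage: Int, num_group_stages: Int) -> Int:
    return num_group_stages + stage


fn arrival_count(
    barrier: Int, num_group_stages: Int, producer_arv_count: Int, consumer_arv_count: Int
) -> Int:
    # A full barrier completes once the producer side has arrived; an empty
    # barrier completes once the consumer side has released the slot.
    if barrier < num_group_stages:
        return producer_arv_count
    return consumer_arv_count


fn main() raises:
    var num_group_stages = 4
    var producer_arv_count = 1  # illustrative values only
    var consumer_arv_count = 2
    for stage in range(num_group_stages):
        var f = full_barrier(stage)
        var e = empty_barrier(stage, num_group_stages)
        assert_equal(
            arrival_count(f, num_group_stages, producer_arv_count, consumer_arv_count),
            producer_arv_count,
        )
        assert_equal(
            arrival_count(e, num_group_stages, producer_arv_count, consumer_arv_count),
            consumer_arv_count,
        )
        print("stage", stage, ": full barrier", f, ", empty barrier", e)
```

Because barrier initialization writes shared memory that every warp later waits on, doing it from a single elected thread avoids duplicate initialization.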
producer
producer[origin: MutOrigin](ref [origin] self) -> BlockScaledInputProducer[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]
Get the producer view used by the TMA load warp.
Returns:
BlockScaledInputProducer
consumer
consumer[origin: MutOrigin](ref [origin] self) -> BlockScaledInputConsumer[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]
Get the consumer view used by the MMA warp.
Returns:
BlockScaledInputConsumer
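The producer view (TMA load warp) and consumer view (MMA warp) each step through the num_group_stages stages of Pipeline. The sketch below models the stage-and-phase bookkeeping commonly used with mbarrier-based ring buffers: the stage index advances after each group, and a phase bit flips on wrap-around so a waiter can tell the current fill of a slot from the previous one. Whether ProducerConsumerPipeline tracks its state exactly this way is an assumption, and advance is a hypothetical helper, not part of the library.

```mojo
from testing import assert_equal


# Hypothetical helper, not the library's API: advance a (stage, phase)
# cursor over a ring of num_stages slots. The phase bit flips every time
# the stage index wraps back to 0.
fn advance(mut stage: Int, mut phase: Int, num_stages: Int):
    stage += 1
    if stage == num_stages:
        stage = 0
        phase = 1 - phase


fn main() raises:
    var num_group_stages = 2
    # A single cursor starting at stage 0, phase 0.
    var stage = 0
    var phase = 0
    for k_iter in range(5):
        print("k-group", k_iter, "-> stage", stage, ", phase", phase)
        advance(stage, phase, num_group_stages)
    # 5 steps over 2 stages: the cursor wrapped twice, so phase is back to 0.
    assert_equal(stage, 1)
    assert_equal(phase, 0)
```

In a real pipeline each side would keep its own cursor; the producer waits on a slot's empty barrier before refilling it, and the consumer waits on its full barrier before reading it.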
producer_stage
consumer_stage