
Mojo struct

BlockScaledTilePipeline

@register_passable("trivial") struct BlockScaledTilePipeline[a_type: DType, b_type: DType, sfa_type: DType, sfb_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, sfa_tile_layout: Layout, sfb_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]

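Staged shared-memory tile storage for the A and B operands and their scale factors (SFA for A, SFB for B), with producer-consumer synchronization between the loading and consuming warps.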

Fields

  • pipeline (Self.Pipeline): Producer-consumer barrier pipeline with num_group_stages stages.
  • a_tiles (Self.ATileArray): Shared-memory array of num_pipeline_stages A tiles.
  • b_tiles (Self.BTileArray): Shared-memory array of num_pipeline_stages B tiles.
  • sfa_tiles (Self.SFATileArray): Shared-memory array of num_pipeline_stages scale-factor tiles for A.
  • sfb_tiles (Self.SFBTileArray): Shared-memory array of num_pipeline_stages scale-factor tiles for B.
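How the two stage counts relate is not stated on this page. The following is a minimal sketch, assuming (unconfirmed) that each of the num_group_stages synchronization stages owns k_group_size consecutive tile slots, so that num_pipeline_stages == num_group_stages * k_group_size. `tile_slot` is a hypothetical helper and is not part of this API.

```mojo
from testing import assert_equal, assert_true


# Hypothetical helper (not part of the API). ASSUMPTION: group stage g owns
# tile slots g * k_group_size ..< (g + 1) * k_group_size in every tile array.
fn tile_slot(group_stage: Int, k_index: Int, k_group_size: Int) -> Int:
    return group_stage * k_group_size + k_index


def main():
    var num_group_stages = 4
    var k_group_size = 2
    # Assumed relationship between the two stage counts.
    var num_pipeline_stages = num_group_stages * k_group_size

    for g in range(num_group_stages):
        for k in range(k_group_size):
            var slot = tile_slot(g, k, k_group_size)
            # Every slot must fall inside the num_pipeline_stages tile arrays.
            assert_true(slot < num_pipeline_stages)
            print("group stage", g, "k", k, "-> tile slot", slot)
```

Under this assumption, one barrier hand-off covers k_group_size tiles of each operand, which reduces the synchronization cost per K step.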

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

ATile

comptime ATile = BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray.Tile

ATileArray

comptime ATileArray = SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128]

BarrierArray

comptime BarrierArray = SMemArrayType[SharedMemBarrier, (num_group_stages * 2)]
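BarrierArray holds num_group_stages * 2 barriers, which suggests one "full" barrier (tile data has landed) and one "empty" barrier (stage may be refilled) per group stage. The following is a minimal sketch of that indexing. It assumes, without confirmation from this page, that the first num_group_stages entries are the full barriers and the remaining entries are the empty barriers. Both helpers are hypothetical.

```mojo
from testing import assert_equal, assert_true


# Hypothetical index helpers. ASSUMPTION: all "full" barriers come first,
# followed by all "empty" barriers.
fn full_barrier_index(stage: Int) -> Int:
    return stage


fn empty_barrier_index(stage: Int, num_group_stages: Int) -> Int:
    return num_group_stages + stage


def main():
    var num_group_stages = 3
    var total = num_group_stages * 2  # size of BarrierArray
    for s in range(num_group_stages):
        var f = full_barrier_index(s)
        var e = empty_barrier_index(s, num_group_stages)
        # The two barriers of a stage are distinct and both in range.
        assert_true(f != e and e < total)
        print("stage", s, "full ->", f, "empty ->", e)
```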

BTile

comptime BTile = BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray.Tile

BTileArray

comptime BTileArray = SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]

Pipeline

comptime Pipeline = ProducerConsumerPipeline[num_group_stages]
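A ring of num_group_stages barrier pairs is reused over and over, so each side has to track both a stage index and a phase bit that flips on every wrap-around. The following is a single-threaded model of that bookkeeping. `StageTracker` is a hypothetical name, not the ProducerConsumerPipeline API. The starting phases follow the common mbarrier convention (producer 1, consumer 0), and this page does not confirm that convention.

```mojo
from testing import assert_equal


# Hypothetical model of per-side stage/phase bookkeeping (not the real API).
struct StageTracker:
    var stage: Int
    var phase: Int
    var num_stages: Int

    fn __init__(out self, num_stages: Int, initial_phase: Int):
        self.stage = 0
        self.phase = initial_phase
        self.num_stages = num_stages

    # Move to the next stage; on wrap-around flip the phase bit so the next
    # wait on a reused barrier targets its next completion.
    fn step(mut self):
        self.stage += 1
        if self.stage == self.num_stages:
            self.stage = 0
            self.phase = 1 - self.phase


def main():
    # Assumed convention: producer starts at phase 1, consumer at phase 0.
    var producer = StageTracker(3, 1)
    for i in range(7):
        print("iteration", i, "stage", producer.stage, "phase", producer.phase)
        producer.step()
```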

SFATile

comptime SFATile = BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFATileArray.Tile

SFATileArray

comptime SFATileArray = SMemTileArrayType[sfa_type, sfa_tile_layout, num_pipeline_stages, 128]

SFBTile

comptime SFBTile = BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFBTileArray.Tile

SFBTileArray

comptime SFBTileArray = SMemTileArrayType[sfb_type, sfb_tile_layout, num_pipeline_stages, 128]
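Each tile array stores num_pipeline_stages tiles, so shared-memory use grows linearly with the stage count. The following is a rough sizing sketch. It assumes that the trailing 128 argument of SMemTileArrayType is a per-tile byte alignment and that a SharedMemBarrier occupies 8 bytes (the size of a PTX mbarrier object). The tile shapes and element sizes in `main` are made-up examples, not defaults of this struct.

```mojo
from testing import assert_equal


# Round x up to the next multiple of alignment.
fn round_up(x: Int, alignment: Int) -> Int:
    return (x + alignment - 1) // alignment * alignment


# Bytes for one tile array: num_stages tiles of rows * cols elements, each
# padded to the ASSUMED 128-byte alignment.
fn tile_array_bytes(rows: Int, cols: Int, elem_bytes: Int, num_stages: Int) -> Int:
    return num_stages * round_up(rows * cols * elem_bytes, 128)


def main():
    var num_group_stages = 4
    var k_group_size = 1
    var num_pipeline_stages = num_group_stages * k_group_size  # assumed relation

    # Hypothetical shapes with 1-byte operand and scale-factor elements.
    var a = tile_array_bytes(128, 128, 1, num_pipeline_stages)
    var b = tile_array_bytes(128, 128, 1, num_pipeline_stages)
    var sfa = tile_array_bytes(128, 4, 1, num_pipeline_stages)
    var sfb = tile_array_bytes(128, 4, 1, num_pipeline_stages)
    var barriers = num_group_stages * 2 * 8  # assumed 8-byte barriers
    print("total shared memory bytes:", a + b + sfa + sfb + barriers)
```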

Methods

__init__

__init__(barriers: SMemArrayType[SharedMemBarrier, (num_group_stages * 2)], a_tiles: SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128], b_tiles: SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128], sfa_tiles: SMemTileArrayType[sfa_type, sfa_tile_layout, num_pipeline_stages, 128], sfb_tiles: SMemTileArrayType[sfb_type, sfb_tile_layout, num_pipeline_stages, 128]) -> Self

Initializes the pipeline from a typed barrier array and the A, B, SFA, and SFB tile arrays.

init_barriers

static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)

Initializes the pipeline barriers. Call it once, from the single thread selected by elect_one.
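init_barriers takes two arrival counts. The following single-threaded model shows how such counts are typically used with phase-tracking barriers. A barrier completes its current phase once the expected number of arrivals has occurred. A waiter passes once the phase it asks about (identified by parity) has completed. `ModelBarrier` is a hypothetical stand-in, not SharedMemBarrier. Assigning producer_arv_count to the full barriers and consumer_arv_count to the empty barriers is an assumption based on the parameter names.

```mojo
from testing import assert_true, assert_false


# Hypothetical model of a phase-tracking barrier (not SharedMemBarrier).
struct ModelBarrier:
    var expected: Int
    var pending: Int
    var completed_phases: Int

    fn __init__(out self, expected: Int):
        self.expected = expected
        self.pending = expected
        self.completed_phases = 0

    # Count one arrival; the phase completes when all expected arrivals are in.
    fn arrive(mut self):
        self.pending -= 1
        if self.pending == 0:
            self.completed_phases += 1
            self.pending = self.expected

    # A wait on `parity` passes once the phase with that parity has completed,
    # i.e. when the current phase's parity differs from `parity`.
    fn ready(self, parity: Int) -> Bool:
        return self.completed_phases % 2 != parity


def main():
    var producer_arv_count = 1  # assumed: one elected producer thread
    var consumer_arv_count = 1  # assumed: one consumer arrival per stage
    var full = ModelBarrier(producer_arv_count)
    var empty = ModelBarrier(consumer_arv_count)

    # Producer waits with parity 1 first, so an untouched stage is free.
    print("producer may fill:", empty.ready(1))
    # Consumer waits with parity 0 and blocks until the producer arrives.
    print("consumer may read before arrive:", full.ready(0))
    full.arrive()
    print("consumer may read after arrive:", full.ready(0))
```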

producer

producer[origin: MutOrigin](ref [origin] self) -> BlockScaledInputProducer[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]

Returns a producer view for the TMA load warp.

Returns:

BlockScaledInputProducer

consumer

consumer[origin: MutOrigin](ref [origin] self) -> BlockScaledInputConsumer[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]

Returns a consumer view for the MMA warp.

Returns:

BlockScaledInputConsumer
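The producer view (TMA load warp) and the consumer view (MMA warp) hand stages back and forth through the full/empty barrier pairs. The following single-threaded simulation models that hand-off as a generic mbarrier ring. The producer waits on "empty", fills the stage, and arrives on "full". The consumer waits on "full", consumes the stage, and arrives on "empty". Everything here is a hypothetical model, not the BlockScaledInputProducer or BlockScaledInputConsumer API. Real warps run concurrently, and here a loop simply alternates them.

```mojo
from testing import assert_equal


# mbarrier-style check: the phase with this parity has completed.
fn ready(completed_phases: Int, parity: Int) -> Bool:
    return completed_phases % 2 != parity


# Returns the stage index each consumed group was read from.
def simulate(num_stages: Int, num_iters: Int) -> List[Int]:
    # Completed-phase counters per stage; ASSUMES one arrival per phase.
    var full = List[Int]()
    var empty = List[Int]()
    for _ in range(num_stages):
        full.append(0)
        empty.append(0)

    var p_stage = 0
    var p_phase = 1  # producer's first wait on each empty barrier passes
    var c_stage = 0
    var c_phase = 0
    var produced = 0
    var consumed = List[Int]()

    while len(consumed) < num_iters:
        # Producer: refill the stage only after the consumer released it.
        if produced < num_iters and ready(empty[p_stage], p_phase):
            full[p_stage] += 1  # arrive on "full" once the tiles have landed
            produced += 1
            p_stage += 1
            if p_stage == num_stages:
                p_stage = 0
                p_phase = 1 - p_phase
        # Consumer: read the stage only after the producer filled it.
        if ready(full[c_stage], c_phase):
            consumed.append(c_stage)
            empty[c_stage] += 1  # arrive on "empty" to release the stage
            c_stage += 1
            if c_stage == num_stages:
                c_stage = 0
                c_phase = 1 - c_phase
    return consumed^


def main():
    var order = simulate(3, 7)
    for i in range(len(order)):
        print("group", i, "consumed from stage", order[i])
```

The phase bits are what stop the producer from overwriting a stage that the consumer has not yet released after a wrap-around.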

producer_stage

producer_stage(self) -> UInt32

Returns the current producer stage index.

Returns:

UInt32

consumer_stage

consumer_stage(self) -> UInt32

Returns the current consumer stage index.

Returns:

UInt32
