Mojo struct
BlockScaledProducerStage
@register_passable(trivial)
struct BlockScaledProducerStage[origin: MutOrigin, a_type: DType, b_type: DType, sfa_type: DType, sfb_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, sfa_tile_layout: Layout, sfb_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]
Context manager for producer tile access with scaling factors.
Fields
- pipeline_ptr (Pointer[BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType, origin]): Pointer, with origin `origin`, to the `TilePipelineType` (a `BlockScaledTilePipeline`) associated with this stage.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
ATile
comptime ATile = BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.ATile
ATileArray
comptime ATileArray = BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.ATileArray
BTile
comptime BTile = BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.BTile
BTileArray
comptime BTileArray = BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.BTileArray
SFATile
comptime SFATile = BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.SFATile
SFATileArray
comptime SFATileArray = BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.SFATileArray
SFBTile
comptime SFBTile = BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.SFBTile
SFBTileArray
comptime SFBTileArray = BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType.SFBTileArray
TilePipelineType
comptime TilePipelineType = BlockScaledTilePipeline[a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]
Methods
__init__
__init__(pipeline_ptr: Pointer[BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].TilePipelineType, origin], stage: UInt32, barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], a_tiles: SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128], b_tiles: SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128], sfa_tiles: SMemTileArrayType[sfa_type, sfa_tile_layout, num_pipeline_stages, 128], sfb_tiles: SMemTileArrayType[sfb_type, sfb_tile_layout, num_pipeline_stages, 128]) -> Self
__enter__
__enter__(mut self) -> Self
__exit__
__exit__(mut self)
get_tile
get_tile(self, k_idx: Int) -> Tuple[BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATile, BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTile]
Get A and B tiles at the specified k-group index.
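Returns:
Tuple[BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATile, BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTile]: The A and B tiles at the specified k-group index.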
get_sf_tile
get_sf_tile(self, k_idx: Int) -> Tuple[BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFATile, BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFBTile]
Get A and B scaling factor tiles at the specified k-group index.
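Returns:
Tuple[BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFATile, BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFBTile]: The A and B scaling factor tiles at the specified k-group index.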
get_a_tile
get_a_tile(self, k_idx: Int) -> BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATile
Get A tile at the specified k-group index.
Returns:
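BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATile: The A tile at the specified k-group index.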
get_b_tile
get_b_tile(self, k_idx: Int) -> BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTile
Get B tile at the specified k-group index.
Returns:
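BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTile: The B tile at the specified k-group index.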
get_sfa_tile
get_sfa_tile(self, k_idx: Int) -> BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFATile
Get A scaling factor tile at the specified k-group index.
Returns:
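BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFATile: The A scaling factor tile at the specified k-group index.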
get_sfb_tile
get_sfb_tile(self, k_idx: Int) -> BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFBTile
Get B scaling factor tile at the specified k-group index.
Returns:
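BlockScaledProducerStage[origin, a_type, b_type, sfa_type, sfb_type, a_tile_layout, b_tile_layout, sfa_tile_layout, sfb_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].SFBTile: The B scaling factor tile at the specified k-group index.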
expect_bytes
expect_bytes(self, num_bytes: Int)
Set expected bytes on the barrier for TMA loads.
barrier
barrier(self) -> MbarPtr
Get the barrier pointer for TMA multicast loads.
Returns:
MbarPtr
stage
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!