Mojo struct
BlockwiseFP8TileWriter
struct BlockwiseFP8TileWriter[c_type: DType, c_smem_layout: Layout, accum_type: DType, accum_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], is_lower_frag_required: Bool, cta_group: Int, num_output_stages: Int, num_output_warps: Scalar[DType.uindex], c_swizzle: TensorMapSwizzle]
Writes register accumulators to global memory (GMEM) via shared memory (SMEM) and the Tensor Memory Accelerator (TMA).
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
bits
comptime bits = 256
BM
comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)
BN
comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)
c_smem_shape0
comptime c_smem_shape0 = c_smem_layout.shape[0].value()
CG1_TMA_BM
comptime CG1_TMA_BM = BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].c_smem_shape0
CG2_TMA_BM
comptime CG2_TMA_BM = c_smem_layout.shape[0].value() if (mma_shape.__getitem__[3, DType.int64, Int](0) == 256) else BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].BM
CTileArray
comptime CTileArray = SMemTileArray[c_type, c_smem_layout, num_output_stages, 128]
data_paths
comptime data_paths = 16
fragment_size
comptime fragment_size = (128 // WARP_SIZE)
fragments_per_stage
comptime fragments_per_stage = (BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].fragment_size * BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].repeats)
MMA_M
comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)
MMA_N
comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)
num_elements
comptime num_elements = accum_layout.shape[1].value()
num_elements_per_load
comptime num_elements_per_load = 8
num_stages
comptime num_stages = accum_layout.shape[0].value()
repeats
comptime repeats = (BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].num_elements // BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].fragment_size)
stageN
comptime stageN = (BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].repeats * 8)
swizzle
comptime swizzle = make_swizzle[c_type, c_swizzle]()
TMA_BM
comptime TMA_BM = c_smem_layout.shape[0].value() if (mma_shape.__getitem__[3, DType.int64, Int](0) == 256) else block_tile_shape.__getitem__[3, DType.int64, Int](0) if (cta_group == 2) else BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].CG1_TMA_BM
Methods
write
static write[c_layout: Layout, c_desc_layout: Layout, cluster_size: Int](accum: BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_frag_required, block_tile_shape, mma_shape, cluster_size], c_tiles: SMemTileArray[c_type, c_smem_layout, num_output_stages, 128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_coord: Tuple[UInt, UInt])
Writes accumulated register tiles to global memory (GMEM) via double-buffered shared memory (SMEM).
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!