Skip to main content

Mojo struct

BlockwiseFP8TileWriter

struct BlockwiseFP8TileWriter[c_type: DType, c_smem_layout: Layout, accum_type: DType, accum_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], is_lower_frag_required: Bool, cta_group: Int, num_output_stages: Int, num_output_warps: Scalar[DType.uindex], c_swizzle: TensorMapSwizzle]

Writes register accumulators to global memory (GMEM) by staging them through shared memory (SMEM) and issuing Tensor Memory Accelerator (TMA) stores.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

bits

comptime bits = 256

BM

comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)

BN

comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)

c_smem_shape0

comptime c_smem_shape0 = c_smem_layout.shape[0].value()

CG1_TMA_BM

comptime CG1_TMA_BM = BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].c_smem_shape0

CG2_TMA_BM

comptime CG2_TMA_BM = c_smem_layout.shape[0].value() if (eq mma_shape.__getitem__[3, DType.int64, Int](0)._mlir_value, 256) else BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].BM

CTileArray

comptime CTileArray = SMemTileArray[c_type, c_smem_layout, num_output_stages, 128]

data_paths

comptime data_paths = 16

fragment_size

comptime fragment_size = (128 // WARP_SIZE)

fragments_per_stage

comptime fragments_per_stage = (BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].fragment_size * BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].repeats)

MMA_M

comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)

MMA_N

comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)

num_elements

comptime num_elements = accum_layout.shape[1].value()

num_elements_per_load

comptime num_elements_per_load = 8

num_stages

comptime num_stages = accum_layout.shape[0].value()

repeats

comptime repeats = (BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].num_elements // BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].fragment_size)

stageN

comptime stageN = (BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].repeats * 8)

swizzle

comptime swizzle = make_swizzle[c_type, c_swizzle]()

TMA_BM

comptime TMA_BM = c_smem_layout.shape[0].value() if (eq mma_shape.__getitem__[3, DType.int64, Int](0)._mlir_value, 256) else block_tile_shape.__getitem__[3, DType.int64, Int](0) if (eq cta_group._mlir_value, 2) else BlockwiseFP8TileWriter[c_type, c_smem_layout, accum_type, accum_layout, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].CG1_TMA_BM

Methods

write

static write[c_layout: Layout, c_desc_layout: Layout, cluster_size: Int](accum: BlockwiseFP8Accumulator[accum_type, accum_layout, is_lower_frag_required, block_tile_shape, mma_shape, cluster_size], c_tiles: SMemTileArray[c_type, c_smem_layout, num_output_stages, 128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_coord: Tuple[UInt, UInt])

Writes the accumulated register tiles to global memory (GMEM) by way of double-buffered shared-memory (SMEM) tiles.

Was this page helpful?