Mojo struct
BlockwiseFP8TileWriter
struct BlockwiseFP8TileWriter[c_type: DType, c_smem_dim0: Int, c_smem_dim1: Int, accum_type: DType, accum_num_stages: Int, accum_num_elements: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], is_lower_frag_required: Bool, cta_group: Int, num_output_stages: Int, num_output_warps: Scalar[DType.uint], c_swizzle: TensorMapSwizzle]
Write register accumulators to GMEM via SMEM and TMA.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
__del__is_trivial
comptime __del__is_trivial = True
bits
comptime bits = 256
BM
comptime BM = block_tile_shape.__getitem__[Int](0)
BN
comptime BN = block_tile_shape.__getitem__[Int](1)
c_smem_layout
comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)
CTileArray
comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]
CTileArrayLT
comptime CTileArrayLT = SMemTileArray[c_type, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].c_smem_layout, num_output_stages, 128]
data_paths
comptime data_paths = 16
fragment_size
comptime fragment_size = (128 // WARP_SIZE)
fragments_per_stage
comptime fragments_per_stage = (BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].fragment_size * BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].repeats)
MMA_M
comptime MMA_M = mma_shape.__getitem__[Int](0)
MMA_N
comptime MMA_N = mma_shape.__getitem__[Int](1)
num_elements
comptime num_elements = accum_num_elements
num_elements_per_load
comptime num_elements_per_load = 8
num_stages
comptime num_stages = accum_num_stages
repeats
comptime repeats = (BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].num_elements // BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].fragment_size)
SMEMWriter
comptime SMEMWriter = TMEMToSMemWriter[c_type, accum_type, c_smem_dim0, c_smem_dim1, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].BM, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].BN, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].MMA_M, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].MMA_N, BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].stageN, cta_group, Int.__init__[Scalar[DType.uint]](num_output_warps), c_swizzle]
stageN
comptime stageN = (BlockwiseFP8TileWriter[c_type, c_smem_dim0, c_smem_dim1, accum_type, accum_num_stages, accum_num_elements, block_tile_shape=block_tile_shape, mma_shape=mma_shape, is_lower_frag_required=is_lower_frag_required, cta_group=cta_group, num_output_stages=num_output_stages, num_output_warps=num_output_warps, c_swizzle=c_swizzle].repeats * 8)
Methods
write
static write[c_layout: Layout, c_desc_layout: Layout, cluster_size: Int](accum: BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_frag_required, block_tile_shape, mma_shape, cluster_size], c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_coord: Tuple[UInt, UInt])
Write accumulated register tiles to GMEM via double-buffered SMEM.
write_absolute_with_bounds_check
static write_absolute_with_bounds_check[c_tensor_layout: TensorLayout, cluster_size: Int](accum: BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_frag_required, block_tile_shape, mma_shape, cluster_size], c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: TileTensor[c_type, c_tensor_layout, MutAnyOrigin])
Write accumulated register tiles to GMEM with bounds checking.
Args:
- accum (BlockwiseFP8Accumulator): Blockwise FP8 accumulator with upper/lower register tiles.
- c_tiles (SMemTileArray2DRowMajor): SMEM tile array for the C output.
- m_abs (UInt32): Absolute M coordinate (start of the tile in token space).
- n_abs (UInt32): Absolute N coordinate (start of the tile).
- m_end (UInt32): End offset for bounds checking (exclusive).
- expert_scale (Float32): Per-expert output scaling factor.
- c_tensor (TileTensor): C tensor in GMEM (a TileTensor, used for bounds-checked stores).
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!