Skip to main content

Mojo struct

BlockwiseFP8TileWriter

struct BlockwiseFP8TileWriter[c_type: DType, c_smem_dim0: Int, c_smem_dim1: Int, accum_type: DType, accum_num_stages: Int, accum_num_elements: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], is_lower_frag_required: Bool, cta_group: Int, num_output_stages: Int, num_output_warps: Scalar[DType.uint], c_swizzle: TensorMapSwizzle]

Writes register accumulators to global memory (GMEM) by staging them through shared memory (SMEM) and issuing TMA stores.

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

bits

comptime bits = 256

BM

comptime BM = block_tile_shape.__getitem__[Int](0)

BN

comptime BN = block_tile_shape.__getitem__[Int](1)

c_smem_layout

comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)

CTileArray

comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]

CTileArrayLT

comptime CTileArrayLT = SMemTileArray[c_type, Self.c_smem_layout, num_output_stages, 128]

data_paths

comptime data_paths = 16

fragment_size

comptime fragment_size = (128 // WARP_SIZE)

fragments_per_stage

comptime fragments_per_stage = Self.fragment_size * Self.repeats

MMA_M

comptime MMA_M = mma_shape.__getitem__[Int](0)

MMA_N

comptime MMA_N = mma_shape.__getitem__[Int](1)

num_elements

comptime num_elements = accum_num_elements

num_elements_per_load

comptime num_elements_per_load = 8

num_stages

comptime num_stages = accum_num_stages

repeats

comptime repeats = Self.num_elements // Self.fragment_size

SMEMWriter

comptime SMEMWriter = TMEMToSMemWriter[c_type, accum_type, c_smem_dim0, c_smem_dim1, Self.BM, Self.BN, Self.MMA_M, Self.MMA_N, Self.stageN, cta_group, Int(num_output_warps), c_swizzle]

stageN

comptime stageN = Self.repeats * 8

Methods

write

static write[c_layout: Layout, c_desc_layout: Layout, cluster_size: Int](accum: BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_frag_required, block_tile_shape, mma_shape, cluster_size], c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_coord: Tuple[UInt, UInt])

Writes the accumulated register tiles to GMEM by staging them in double-buffered SMEM (`c_tiles`) and storing them with the TMA operation `c_tma_op`.

write_absolute_with_bounds_check

static write_absolute_with_bounds_check[c_tensor_layout: TensorLayout, cluster_size: Int](accum: BlockwiseFP8Accumulator[accum_type, accum_num_stages, accum_num_elements, is_lower_frag_required, block_tile_shape, mma_shape, cluster_size], c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], m_abs: UInt32, n_abs: UInt32, m_end: UInt32, expert_scale: Float32, c_tensor: TileTensor[c_type, c_tensor_layout, MutAnyOrigin])

Write accumulated register tiles to GMEM with bounds checking.

Args:

  • accum (BlockwiseFP8Accumulator): Blockwise FP8 accumulator with upper/lower register tiles.
  • c_tiles (SMemTileArray2DRowMajor): SMEM tile array for C output.
  • m_abs (UInt32): Absolute M coordinate (start of tile in token space).
  • n_abs (UInt32): Absolute N coordinate (start of tile).
  • m_end (UInt32): End offset for bounds checking (exclusive).
  • expert_scale (Float32): Per-expert output scaling factor.
  • c_tensor (TileTensor): C tensor in GMEM (TileTensor for bounds-checked stores).

Was this page helpful?