Skip to main content

Mojo struct

TileWriter

@register_passable(trivial) struct TileWriter[tma_origin: ImmutOrigin, c_type: DType, c_layout: Layout, c_desc_layout: Layout, //, a_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int, num_accum_pipeline_stages: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, c_smem_layout: Layout, num_output_stages: Int, stage_stride_cols: Int, num_output_warps: Int, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True]

Output tile writer for SM100 matmul epilogue.

Stores a pointer to the TMA descriptor. SMEM tiles are passed in on each call.

Parameters are passed explicitly to work with both MatmulConfig and BlockScaledMatmulConfig.

The stage_stride_cols parameter must match the value used when constructing the OutputTilePipeline that provides OutputStage instances to the write() method.

Fields

  • c_tma_op (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].TmaOpPtr): Pointer to the TMA descriptor.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

accum_tile_layout

comptime accum_tile_layout = Layout.row_major(TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].BM, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].stageN)

AccumTmemArray

comptime AccumTmemArray = TmemArrayType[accum_type, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].accum_tile_layout, TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].num_stages, cta_group=cta_group]

bits

comptime bits = 256

BM

comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)

BN

comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)

cg1_num_stages

comptime cg1_num_stages = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].stageN)

cg2_num_stages

comptime cg2_num_stages = (mma_shape.__getitem__[3, DType.int64, Int](1) // c_smem_layout.shape[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].N_dim].value()) if (eq mma_shape.__getitem__[3, DType.int64, Int](0)._mlir_value, 256) else ((TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].stageN) // 2)

CTileArray

comptime CTileArray = SMemTileArray[c_type, c_smem_layout, num_output_stages, 128]

data_paths

comptime data_paths = 16

epilogue_dtype

comptime epilogue_dtype = c_type if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> a_type, "_mlir_value">>, 80) else DType.float32

fragment_size

comptime fragment_size = (128 // WARP_SIZE)

is_lower_frag_required

comptime is_lower_frag_required = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].BM == 64) if (eq cta_group._mlir_value, 1) else (cta_group == 1).__bool__().__invert__()

MMA_M

comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)

MMA_N

comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)

N_dim

comptime N_dim = 0 if transpose_c else 1

num_stages

comptime num_stages = (mma_shape.__getitem__[3, DType.int64, Int](1) // c_smem_layout.shape[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].N_dim].value()) if (eq mma_shape.__getitem__[3, DType.int64, Int](0)._mlir_value, 256) else ((TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].MMA_N // TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].stageN) // 2) if (eq cta_group._mlir_value, 2) else TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].cg1_num_stages

rep

comptime rep = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].stageN // 8)

rep_frag_size

comptime rep_frag_size = (TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].rep * TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].fragment_size)

Stage

comptime Stage = OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group]

stage_contiguous_size

comptime stage_contiguous_size = c_smem_layout.shape[1].value()

stageN

comptime stageN = c_smem_layout.shape[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].N_dim].value()

TmaOp

comptime TmaOp = TMATensorTile[c_type, c_layout, c_desc_layout]

TmaOpPtr

comptime TmaOpPtr = Pointer[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].TmaOp, tma_origin]

Methods

__init__

__init__(c_tma_op: Pointer[TileWriter[a_type, accum_type, block_tile_shape, mma_shape, cta_group, num_accum_pipeline_stages, c_swizzle, transpose_c, c_smem_layout, num_output_stages, stage_stride_cols, num_output_warps, elementwise_compute_lambda_fn, register_based_epilogue].TmaOp, tma_origin]) -> Self

Initialize with a pointer to the TMA descriptor.

write

write(self, c_tiles: SMemTileArray[c_type, c_smem_layout, num_output_stages, 128], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], tile_coord: Tuple[UInt32, UInt32], shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write accumulated results to global memory.

write_splitk

write_splitk[reduction_layout: Layout](self, c_tiles: SMemTileArray[c_type, c_smem_layout, num_output_stages, 128], stage: OutputStage[num_accum_pipeline_stages, stage_stride_cols, cta_group], scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], reduction_tensor: LayoutTensor[accum_type, reduction_layout, MutAnyOrigin], work_info: WorkInfo, shape: Tuple[UInt32, UInt32], elect_one_warp: Bool)

Write with split-K reduction. Only the last split writes to global memory (GMEM).

Was this page helpful?