
Mojo struct

SMemEpilogueWriter

struct SMemEpilogueWriter[c_type: DType, num_output_stages: Int, //, c_smem_dim0: Int, c_smem_dim1: Int, epilogue_dtype: DType, epc: EpilogueConfig, num_output_warps: Int, c_swizzle: TensorMapSwizzle, simd_size: Int, stage: Int, rep_frag_size: Int, compute_lambda_fn: def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]]

SMEM-based epilogue: writes accumulator tiles to shared memory (SMEM) and applies the epilogue lambda (compute_lambda_fn) to them there.

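Fields

  • warp_id (UInt32): warp ID supplied to __init__.
  • c_tiles (SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, epc, num_output_warps, c_swizzle, simd_size, stage, rep_frag_size, compute_lambda_fn].CTileArray): SMEM output tile array supplied to __init__.
  • M (UInt32): number of output rows, from c_shape.
  • N (UInt32): number of output columns, from c_shape.
  • c_row (UInt32): output row coordinate, from c_coord.
  • c_col (UInt32): output column coordinate, from c_coord.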

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

barrier_threads

comptime barrier_threads = (num_output_warps * WARP_SIZE)
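As a worked check of this formula, the Python sketch below evaluates barrier_threads for a few warp counts. It assumes WARP_SIZE = 32, the NVIDIA warp width. The helper name barrier_threads is illustrative only and is not part of the Mojo API.

```python
# Assumption: NVIDIA warp width of 32 threads.
WARP_SIZE = 32

def barrier_threads(num_output_warps: int) -> int:
    """Mirror of: comptime barrier_threads = num_output_warps * WARP_SIZE."""
    return num_output_warps * WARP_SIZE

for warps in (1, 2, 4):
    print(warps, "output warps ->", barrier_threads(warps), "threads")
```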

BM

comptime BM = epc.BM

BN

comptime BN = epc.BN

c_smem_layout

comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)
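In a row-major layout, element (i, j) of a c_smem_dim0 x c_smem_dim1 tile sits at linear offset i * c_smem_dim1 + j, so c_smem_dim1 is the contiguous dimension. The sketch below shows this with a hypothetical 64 x 32 tile. The helper row_major_offset is illustrative and is not the Layout API.

```python
C_SMEM_DIM0, C_SMEM_DIM1 = 64, 32  # hypothetical tile shape

def row_major_offset(i: int, j: int, dim1: int = C_SMEM_DIM1) -> int:
    """Linear element offset of (i, j) in a row-major dim0 x dim1 layout."""
    return i * dim1 + j

# Neighboring columns are adjacent in memory; neighboring rows are dim1 apart.
print(row_major_offset(0, 1))                              # 1
print(row_major_offset(1, 0))                              # 32
print(row_major_offset(C_SMEM_DIM0 - 1, C_SMEM_DIM1 - 1))  # 2047
```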

cta_group

comptime cta_group = epc.cta_group

CTileArray

comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]
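CTileArray holds num_output_stages tiles of c_smem_dim0 x c_smem_dim1 elements of c_type. The sketch below estimates its SMEM footprint. It assumes the stages are packed back to back with no padding or extra alignment. The dtype byte sizes are standard, but the helper and the example shape are hypothetical.

```python
DTYPE_BYTES = {"bfloat16": 2, "float16": 2, "float32": 4}

def c_tile_array_bytes(c_type: str, dim0: int, dim1: int, num_output_stages: int) -> int:
    """Assumed footprint: stages * rows * cols * element size (no padding)."""
    return num_output_stages * dim0 * dim1 * DTYPE_BYTES[c_type]

# Hypothetical: two stages of 64 x 32 bfloat16 tiles.
print(c_tile_array_bytes("bfloat16", 64, 32, 2))  # 8192
```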

data_paths

comptime data_paths = 16

is_lower_frag_required

comptime is_lower_frag_required = epc.is_lower_frag_required

MMA_M

comptime MMA_M = epc.MMA_M

MMA_N

comptime MMA_N = epc.MMA_N

num_stages

comptime num_stages = epc.num_stages

OutputSyncBarrier

comptime OutputSyncBarrier = WarpGroupBarrier[SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, epc, num_output_warps, c_swizzle, simd_size, stage, rep_frag_size, compute_lambda_fn].barrier_threads]
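OutputSyncBarrier is sized to barrier_threads, so every output-warp thread must arrive before any of them proceeds. The CPU analogy below uses Python's threading.Barrier to illustrate the general pattern: each thread writes its own slot of a shared buffer and then waits. After the wait, each thread can safely read a slot that a different thread wrote. This is only an illustration. It assumes this write, barrier, read usage and does not describe how WarpGroupBarrier is implemented on the GPU. NUM_OUTPUT_WARPS is hypothetical.

```python
import threading

NUM_OUTPUT_WARPS = 4  # hypothetical
WARP_SIZE = 32        # assumed NVIDIA warp width
BARRIER_THREADS = NUM_OUTPUT_WARPS * WARP_SIZE

smem = [None] * BARRIER_THREADS          # stand-in for a shared SMEM tile
barrier = threading.Barrier(BARRIER_THREADS)
results = []
lock = threading.Lock()

def worker(tid: int) -> None:
    smem[tid] = tid * 2                  # phase 1: write own slot
    barrier.wait()                       # nobody continues until all have written
    nxt = (tid + 1) % BARRIER_THREADS    # phase 2: read another thread's slot
    with lock:
        results.append(smem[nxt] == nxt * 2)

threads = [threading.Thread(target=worker, args=(t,)) for t in range(BARRIER_THREADS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(all(results), len(results))  # True 128
```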

stage_contiguous_size

comptime stage_contiguous_size = c_smem_dim1

stageN

comptime stageN = c_smem_dim0 if SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, epc, num_output_warps, c_swizzle, simd_size, stage, rep_frag_size, compute_lambda_fn].transpose_c else c_smem_dim1
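stageN selects the N extent of a stage from whichever SMEM dimension holds N. When transpose_c is set, that is c_smem_dim0; otherwise it is c_smem_dim1. A direct Python transcription of the conditional, using a hypothetical 64 x 32 tile:

```python
def stage_n(c_smem_dim0: int, c_smem_dim1: int, transpose_c: bool) -> int:
    # Mirrors: stageN = c_smem_dim0 if transpose_c else c_smem_dim1
    return c_smem_dim0 if transpose_c else c_smem_dim1

print(stage_n(64, 32, transpose_c=False))  # 32
print(stage_n(64, 32, transpose_c=True))   # 64
```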

swizzle

comptime swizzle = make_swizzle[c_type, c_swizzle]()
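For intuition about what a swizzle does, the sketch below implements the XOR pattern commonly used to describe a 128-byte TMA swizzle. Within each 1024-byte atom, the 16-byte chunk index (bits 4 to 6 of the byte offset) is XORed with the 128-byte row index modulo 8 (bits 7 to 9). This is an assumption for illustration only. It is not claimed to be what make_swizzle[c_type, c_swizzle]() returns, which may, for example, operate in element units rather than bytes.

```python
def swizzle_128b(byte_offset: int) -> int:
    """Assumed 128B XOR swizzle: chunk bits [4:6] ^= row bits [7:9]."""
    row_in_atom = (byte_offset >> 7) & 0x7
    return byte_offset ^ (row_in_atom << 4)

# Row 0 is unchanged; in row 1, chunk 0 moves to chunk 1 (byte 128 -> 144).
print(swizzle_128b(0), swizzle_128b(128))  # 0 144
# Within one 1024-byte atom the mapping is a permutation (no two offsets collide).
print(sorted(swizzle_128b(o) for o in range(1024)) == list(range(1024)))  # True
```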

swizzle_width

comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
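swizzle_width converts the swizzle span from bytes to elements of c_type. The Python worked example below uses the swizzle-mode names from TensorMapSwizzle (32B, 64B, 128B). The lookup tables and helper are illustrative, not the Mojo API.

```python
SWIZZLE_BYTES = {"SWIZZLE_32B": 32, "SWIZZLE_64B": 64, "SWIZZLE_128B": 128}
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float8_e4m3fn": 1}

def swizzle_width(swizzle: str, c_type: str) -> int:
    # Mirrors: swizzle_width = c_swizzle.bytes() // size_of[c_type]()
    return SWIZZLE_BYTES[swizzle] // DTYPE_BYTES[c_type]

print(swizzle_width("SWIZZLE_128B", "bfloat16"))  # 64 elements
print(swizzle_width("SWIZZLE_128B", "float32"))   # 32 elements
print(swizzle_width("SWIZZLE_64B", "bfloat16"))   # 32 elements
```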

Tile

comptime Tile = AccumTile[epilogue_dtype, rep_frag_size]

transpose_c

comptime transpose_c = epc.transpose_c

Methods

__init__

__init__(warp_id: UInt32, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], c_shape: Tuple[UInt32, UInt32], c_coord: Tuple[UInt32, UInt32]) -> Self

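Initialize the writer from a TileTensor array of SMEM output tiles (c_tiles), the output shape (c_shape), and the output coordinates (c_coord).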
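The scalar fields appear to be unpacked directly from the __init__ arguments. The hypothetical Python mirror below makes that mapping explicit. It assumes c_shape is ordered (M, N) and c_coord is ordered (c_row, c_col), and it omits c_tiles. EpilogueWriterState is an illustrative name and is not part of the Mojo API.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class EpilogueWriterState:
    """Hypothetical mirror of the scalar fields set by __init__ (c_tiles omitted)."""
    warp_id: int
    M: int
    N: int
    c_row: int
    c_col: int

    @classmethod
    def from_args(cls, warp_id: int, c_shape: Tuple[int, int], c_coord: Tuple[int, int]):
        M, N = c_shape          # assumed order: (rows, cols)
        c_row, c_col = c_coord  # assumed order: (row, col)
        return cls(warp_id, M, N, c_row, c_col)

s = EpilogueWriterState.from_args(1, (4096, 2048), (128, 256))
print(s)  # EpilogueWriterState(warp_id=1, M=4096, N=2048, c_row=128, c_col=256)
```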

write_tile

write_tile(self, tile: AccumTile[epilogue_dtype, rep_frag_size])

Write the accumulator tile to SMEM and apply the epilogue lambda (compute_lambda_fn).
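compute_lambda_fn receives a 2-D index (IndexList[2]) together with a SIMD vector of values and returns a vector of the same type. The scalar Python model below shows that contract in an SMEM epilogue pass: each element of a tile is passed, with its global (row, col) coordinate, to an elementwise function. Several details are assumptions, not documented behavior: c_row and c_col are treated as element offsets of the tile origin, out-of-bounds elements are left untouched, and the SIMD width is 1. The helper apply_epilogue_lambda and the bias lambda are hypothetical.

```python
def apply_epilogue_lambda(tile, c_row, c_col, M, N, fn):
    """Hypothetical scalar model: apply fn((row, col), value) to in-bounds elements."""
    out = [row[:] for row in tile]              # copy of the SMEM-resident tile
    for i, row in enumerate(tile):
        for j, value in enumerate(row):
            gi, gj = c_row + i, c_col + j       # assumed: element offsets
            if gi < M and gj < N:               # assumed bounds check
                out[i][j] = fn((gi, gj), value)
    return out

# Example lambda: add a per-column bias, as a fused bias epilogue might.
bias = [10.0, 20.0, 30.0]
tile = [[1.0, 2.0], [3.0, 4.0]]
result = apply_epilogue_lambda(tile, c_row=0, c_col=2, M=2, N=3,
                               fn=lambda idx, v: v + bias[idx[1]])
print(result)  # [[31.0, 2.0], [33.0, 4.0]]
```

Column 3 lies outside N = 3, so the second column of the tile keeps its original values.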