
Mojo struct

SMemEpilogueWriter

struct SMemEpilogueWriter[c_type: DType, num_output_stages: Int, //, c_smem_dim0: Int, c_smem_dim1: Int, epilogue_dtype: DType, epc: EpilogueConfig, num_output_warps: Int, c_swizzle: TensorMapSwizzle, simd_size: Int, stage: Int, rep_frag_size: Int, compute_lambda_fn: def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]]

SMEM-based epilogue: write the accumulators to shared memory (SMEM) and apply the epilogue lambda (compute_lambda_fn) there.

Fields

  • warp_id (UInt32)
  • c_tiles (SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, epc, num_output_warps, c_swizzle, simd_size, stage, rep_frag_size, compute_lambda_fn].CTileArray)
  • M (UInt32)
  • N (UInt32)
  • c_row (UInt32)
  • c_col (UInt32)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

barrier_threads

comptime barrier_threads = (num_output_warps * WARP_SIZE)
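A small arithmetic sketch of barrier_threads in Python. It assumes WARP_SIZE is 32 (the NVIDIA warp width; this page does not state it) and uses 4 as an example num_output_warps. If WarpGroupBarrier is backed by a PTX named barrier (also an assumption), the participating thread count must be a multiple of the warp size, which this product satisfies by construction.

```python
# Sketch of: comptime barrier_threads = (num_output_warps * WARP_SIZE)
# Assumption: WARP_SIZE is 32 (NVIDIA); num_output_warps = 4 is an example value.
WARP_SIZE = 32


def barrier_threads(num_output_warps: int) -> int:
    """Number of threads expected to arrive at the output-sync barrier."""
    return num_output_warps * WARP_SIZE


threads = barrier_threads(4)
print(threads)  # 128
# A named-barrier thread count must be a whole number of warps.
assert threads % WARP_SIZE == 0
```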

BM

comptime BM = epc.BM

BN

comptime BN = epc.BN

c_smem_layout

comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)
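Layout.row_major(c_smem_dim0, c_smem_dim1) stores the elements of each row contiguously, which is consistent with stage_contiguous_size being defined as c_smem_dim1. A minimal Python sketch of the row-major offset arithmetic, with example dimensions chosen only for illustration:

```python
def row_major_offset(row: int, col: int, dim0: int, dim1: int) -> int:
    """Linear element offset of (row, col) in a dim0 x dim1 row-major tile."""
    assert 0 <= row < dim0 and 0 <= col < dim1
    # Each row holds dim1 contiguous elements, so skip `row` full rows.
    return row * dim1 + col


# Hypothetical tile shape: 64 rows x 32 columns.
C_SMEM_DIM0, C_SMEM_DIM1 = 64, 32
print(row_major_offset(0, 31, C_SMEM_DIM0, C_SMEM_DIM1))  # 31
print(row_major_offset(1, 0, C_SMEM_DIM0, C_SMEM_DIM1))   # 32
```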

cta_group

comptime cta_group = epc.cta_group

CTileArray

comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]
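CTileArray holds num_output_stages tiles of c_smem_dim0 x c_smem_dim1 elements of c_type. The Python sketch below estimates the resulting shared-memory footprint, assuming the tiles are packed back to back with no padding or alignment gaps; the example shape and the dtype name strings are illustrative, not taken from this page.

```python
# Byte sizes of a few common output dtypes (names are illustrative strings).
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2}


def c_tile_array_bytes(c_type: str, dim0: int, dim1: int,
                       num_output_stages: int) -> int:
    """SMEM bytes for num_output_stages dense tiles of dim0 x dim1 elements.

    Assumption: no padding between or inside tiles.
    """
    return num_output_stages * dim0 * dim1 * DTYPE_BYTES[c_type]


# Example: two stages of 128 x 32 bfloat16 tiles.
print(c_tile_array_bytes("bfloat16", 128, 32, 2))  # 16384
```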

data_paths

comptime data_paths = 16

is_lower_frag_required

comptime is_lower_frag_required = epc.is_lower_frag_required

MMA_M

comptime MMA_M = epc.MMA_M

MMA_N

comptime MMA_N = epc.MMA_N

num_stages

comptime num_stages = epc.num_stages

OutputSyncBarrier

comptime OutputSyncBarrier = WarpGroupBarrier[SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, epc, num_output_warps, c_swizzle, simd_size, stage, rep_frag_size, compute_lambda_fn].barrier_threads]

stage_contiguous_size

comptime stage_contiguous_size = c_smem_dim1

stageN

comptime stageN = c_smem_dim0 if SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, epc, num_output_warps, c_swizzle, simd_size, stage, rep_frag_size, compute_lambda_fn].transpose_c else c_smem_dim1
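stageN picks one of the two SMEM tile dimensions depending on transpose_c. The sketch restates that conditional in Python with made-up dimensions; reading stageN as the N extent covered by one output stage is an interpretation based on its name, not something this page states.

```python
def stage_n(c_smem_dim0: int, c_smem_dim1: int, transpose_c: bool) -> int:
    """Mirror of: c_smem_dim0 if transpose_c else c_smem_dim1."""
    # With a transposed C tile, the N extent comes from the first SMEM dimension.
    return c_smem_dim0 if transpose_c else c_smem_dim1


print(stage_n(128, 32, transpose_c=False))  # 32
print(stage_n(128, 32, transpose_c=True))   # 128
```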

swizzle

comptime swizzle = make_swizzle[c_type, c_swizzle]()
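make_swizzle is not documented on this page. For intuition only, the sketch below models the 128-byte TMA swizzle pattern on byte offsets: within each 1024-byte period, the 16-byte chunk index (address bits 4 to 6) is XORed with the row index modulo 8 (bits 7 to 9), a mapping often written Swizzle<3, 4, 3> in CUTLASS notation. Whether make_swizzle[c_type, c_swizzle]() produces exactly this mapping, and whether it works in bytes or in elements of c_type, is an assumption.

```python
def swizzle_128b(byte_offset: int) -> int:
    """XOR-permute 16-byte chunks within each 128-byte row (1024-byte period).

    Assumption: models a 128B swizzle on byte offsets; make_swizzle may
    express the same pattern in element units instead.
    """
    bits, base, shift = 3, 4, 3                     # Swizzle<B=3, M=4, S=3>
    yyy_mask = ((1 << bits) - 1) << (base + shift)  # bits [7:9]: row mod 8
    # Move the row bits down onto the chunk bits [4:6] and XOR them in.
    return byte_offset ^ ((byte_offset & yyy_mask) >> shift)


# Row 1, chunk 0 (byte 128) lands in chunk 1 of the same row.
print(swizzle_128b(128))  # 144
```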

swizzle_width

comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
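swizzle_width converts the swizzle span from bytes into a count of c_type elements. A Python sketch of that division, using assumed byte widths for the swizzled modes and a few common dtypes (the mode and dtype names are illustrative strings, not the Mojo enum values):

```python
# Assumed byte widths of the swizzled TensorMapSwizzle modes.
SWIZZLE_BYTES = {"SWIZZLE_32B": 32, "SWIZZLE_64B": 64, "SWIZZLE_128B": 128}
# Element sizes in bytes for a few dtypes.
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float8_e4m3fn": 1}


def swizzle_width(c_swizzle: str, c_type: str) -> int:
    """Elements of c_type spanned by one swizzle row (integer division)."""
    return SWIZZLE_BYTES[c_swizzle] // DTYPE_BYTES[c_type]


print(swizzle_width("SWIZZLE_128B", "bfloat16"))  # 64
print(swizzle_width("SWIZZLE_128B", "float32"))   # 32
```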

Tile

comptime Tile = AccumTile[epilogue_dtype, rep_frag_size]

transpose_c

comptime transpose_c = epc.transpose_c

Methods

__init__

def __init__(warp_id: UInt32, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], c_shape: Tuple[UInt32, UInt32], c_coord: Tuple[UInt32, UInt32]) -> Self

Initialize from a TileTensor array in SMEM (c_tiles), the output shape (c_shape), and the output coordinates (c_coord).

write_tile

def write_tile(self, tile: AccumTile[epilogue_dtype, rep_frag_size])

Write the accumulator tile to SMEM and apply the epilogue lambda (compute_lambda_fn).
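Per the struct signature, compute_lambda_fn receives a 2-D coordinate and a SIMD vector of width elements and must return a vector of the same dtype and width. The Python sketch below models that contract with a hypothetical bias-add lambda and a simplified, hypothetical write_tile_model that first stores a tile into a row-major buffer and then applies the lambda in place to each simd_size-wide chunk. It assumes the coordinate passed to the lambda is the global position (c_row + r, c_col + c) and that chunks outside M x N are skipped; neither detail is confirmed on this page, and the Mojo-side dtype and alignment parameters are ignored.

```python
from typing import Callable, List, Tuple

Vec = List[float]
Lambda = Callable[[Tuple[int, int], Vec], Vec]


def bias_lambda(bias: List[float]) -> Lambda:
    """Hypothetical epilogue lambda: add a per-column bias to one SIMD chunk."""
    def fn(coord: Tuple[int, int], vec: Vec) -> Vec:
        _, col = coord
        # Return a vector of the same width, as the Mojo signature requires.
        return [v + bias[col + i] for i, v in enumerate(vec)]
    return fn


def write_tile_model(tile: List[Vec], smem: List[float], dim1: int,
                     c_row: int, c_col: int, M: int, N: int,
                     simd_size: int, fn: Lambda) -> None:
    """Hypothetical model of write_tile: store to row-major SMEM, then apply fn."""
    # Step 1: copy accumulator rows into the row-major SMEM buffer.
    for r, row in enumerate(tile):
        smem[r * dim1:r * dim1 + len(row)] = row
    # Step 2: apply the lambda in place to simd_size-wide chunks.
    for r, row in enumerate(tile):
        for c0 in range(0, len(row), simd_size):
            gr, gc = c_row + r, c_col + c0  # assumed global coordinates
            if gr >= M or gc >= N:
                continue  # assumed bounds check against the output shape
            width = min(simd_size, len(row) - c0, N - gc)
            base = r * dim1 + c0
            smem[base:base + width] = fn((gr, gc), smem[base:base + width])


# Usage: a 2 x 8 tile of accumulators with a column bias of 0..7.
tile = [[1.0] * 8, [2.0] * 8]
smem = [0.0] * 16
write_tile_model(tile, smem, dim1=8, c_row=0, c_col=0, M=2, N=8,
                 simd_size=4, fn=bias_lambda([float(j) for j in range(8)]))
print(smem[:4])  # [1.0, 2.0, 3.0, 4.0]
```

Splitting the work into a store pass and a lambda pass mirrors the "write, then apply in SMEM" description; a real kernel would distribute both passes across the output warps.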