Skip to main content

Mojo struct

SMemEpilogueWriter

@register_passable(trivial) struct SMemEpilogueWriter[c_type: DType, num_output_stages: Int, //, c_smem_dim0: Int, c_smem_dim1: Int, epilogue_dtype: DType, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, is_lower_frag_required: Bool, num_stages: Int, simd_size: Int, stage: Int, rep_frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type]

SMEM-based epilogue: writes accumulator tiles into shared memory (SMEM) and applies the epilogue compute lambda (`compute_lambda_fn`) to them there.

Fields

  • warp_id (UInt32):
  • c_tiles (SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].CTileArrayLT):
  • M (UInt32):
  • N (UInt32):
  • c_row (UInt32):
  • c_col (UInt32):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

barrier_threads

comptime barrier_threads = (num_output_warps * WARP_SIZE)

c_smem_layout

comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)

CTileArrayLT

comptime CTileArrayLT = SMemTileArray[c_type, SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].c_smem_layout, num_output_stages, 128]

data_paths

comptime data_paths = 16

OutputSyncBarrier

comptime OutputSyncBarrier = WarpGroupBarrier[SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].barrier_threads]

stage_contiguous_size

comptime stage_contiguous_size = c_smem_dim1

stageN

comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1

swizzle

comptime swizzle = make_swizzle[c_type, c_swizzle]()

swizzle_width

comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())

Tile

comptime Tile = AccumTile[epilogue_dtype, rep_frag_size]

Methods

__init__

__init__(warp_id: UInt32, c_tiles: SMemTileArray[c_type, SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].c_smem_layout, num_output_stages, 128], c_shape: Tuple[UInt32, UInt32], c_coord: Tuple[UInt32, UInt32]) -> Self

Initialize the SMEM epilogue writer.

__init__(warp_id: UInt32, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], c_shape: Tuple[UInt32, UInt32], c_coord: Tuple[UInt32, UInt32]) -> Self

Initialize from a TileTensor-based `SMemTileArray2DRowMajor` tile array, which is converted internally to the `CTileArrayLT` form stored in `c_tiles`.

write_tile

write_tile(self, tile: AccumTile[epilogue_dtype, rep_frag_size])

Write the accumulator tile to SMEM and apply the epilogue lambda.

Was this page helpful?