Mojo struct
SMemEpilogueWriter
struct SMemEpilogueWriter[c_type: DType, num_output_stages: Int, //, c_smem_dim0: Int, c_smem_dim1: Int, epilogue_dtype: DType, epc: EpilogueConfig, num_output_warps: Int, c_swizzle: TensorMapSwizzle, simd_size: Int, stage: Int, rep_frag_size: Int, compute_lambda_fn: def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]]
SMEM-based epilogue writer: writes accumulator tiles to shared memory (SMEM) and applies the epilogue lambda (compute_lambda_fn) there.
Fields
- warp_id (UInt32):
- c_tiles (SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, epc, num_output_warps, c_swizzle, simd_size, stage, rep_frag_size, compute_lambda_fn].CTileArray):
- M (UInt32):
- N (UInt32):
- c_row (UInt32):
- c_col (UInt32):
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
barrier_threads
comptime barrier_threads = (num_output_warps * WARP_SIZE)
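The barrier_threads expression is plain arithmetic, so it can be checked outside Mojo. The following is a minimal Python sketch, not part of the library: the function name is hypothetical, and WARP_SIZE = 32 is an assumption (the usual NVIDIA warp width), since this page does not state its value.

```python
# Python model of: comptime barrier_threads = (num_output_warps * WARP_SIZE)
# Assumption: WARP_SIZE is 32; the page itself does not give the value.
WARP_SIZE = 32


def barrier_threads(num_output_warps: int, warp_size: int = WARP_SIZE) -> int:
    """Number of threads covered by the output warps (hypothetical helper)."""
    return num_output_warps * warp_size


print(barrier_threads(4))  # 128
```

Under that assumption, four output warps correspond to 128 threads, which is the count passed to WarpGroupBarrier in OutputSyncBarrier.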
BM
comptime BM = epc.BM
BN
comptime BN = epc.BN
c_smem_layout
comptime c_smem_layout = Layout.row_major(c_smem_dim0, c_smem_dim1)
cta_group
comptime cta_group = epc.cta_group
CTileArray
comptime CTileArray = SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages]
data_paths
comptime data_paths = 16
is_lower_frag_required
comptime is_lower_frag_required = epc.is_lower_frag_required
MMA_M
comptime MMA_M = epc.MMA_M
MMA_N
comptime MMA_N = epc.MMA_N
num_stages
comptime num_stages = epc.num_stages
OutputSyncBarrier
comptime OutputSyncBarrier = WarpGroupBarrier[SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, epc, num_output_warps, c_swizzle, simd_size, stage, rep_frag_size, compute_lambda_fn].barrier_threads]
stage_contiguous_size
comptime stage_contiguous_size = c_smem_dim1
stageN
comptime stageN = c_smem_dim0 if SMemEpilogueWriter[c_smem_dim0, c_smem_dim1, epilogue_dtype, epc, num_output_warps, c_swizzle, simd_size, stage, rep_frag_size, compute_lambda_fn].transpose_c else c_smem_dim1
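The stageN selection depends only on transpose_c and the two SMEM tile dimensions. As a rough illustration, here is a Python model of that conditional; the function name and the sample dimensions are hypothetical and are not taken from the library.

```python
# Python model of:
#   comptime stageN = c_smem_dim0 if transpose_c else c_smem_dim1
def stage_n(c_smem_dim0: int, c_smem_dim1: int, transpose_c: bool) -> int:
    """Pick the SMEM tile dimension used as stageN (hypothetical helper)."""
    return c_smem_dim0 if transpose_c else c_smem_dim1


# Sample dimensions chosen only for illustration.
print(stage_n(128, 32, transpose_c=False))  # 32
print(stage_n(128, 32, transpose_c=True))   # 128
```

Note that stage_contiguous_size is always c_smem_dim1, so it matches stageN only in the non-transposed case.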
swizzle
comptime swizzle = make_swizzle[c_type, c_swizzle]()
swizzle_width
comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
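swizzle_width converts the swizzle span from bytes to elements of c_type. The Python sketch below models that integer division; the function name is hypothetical, and the byte counts used (a 128-byte or 64-byte swizzle, 2-byte and 4-byte element types) are illustrative assumptions rather than values documented on this page.

```python
# Python model of:
#   comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
def swizzle_width(swizzle_bytes: int, elem_bytes: int) -> int:
    """Swizzle span measured in elements of c_type (hypothetical helper)."""
    if elem_bytes <= 0:
        raise ValueError("elem_bytes must be positive")
    return swizzle_bytes // elem_bytes


# Assumed inputs: 128-byte swizzle with a 2-byte element type.
print(swizzle_width(128, 2))  # 64
# Assumed inputs: 64-byte swizzle with a 4-byte element type.
print(swizzle_width(64, 4))   # 16
```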
Tile
comptime Tile = AccumTile[epilogue_dtype, rep_frag_size]
transpose_c
comptime transpose_c = epc.transpose_c
Methods
__init__
__init__(warp_id: UInt32, c_tiles: SMemTileArray2DRowMajor[c_type, c_smem_dim0, c_smem_dim1, num_output_stages], c_shape: Tuple[UInt32, UInt32], c_coord: Tuple[UInt32, UInt32]) -> Self
Initialize from a TileTensor array.
write_tile
write_tile(self, tile: AccumTile[epilogue_dtype, rep_frag_size])
Write an accumulator tile to SMEM and apply the epilogue lambda.