Mojo struct
SMemEpilogueWriter
@register_passable(trivial)
struct SMemEpilogueWriter[c_type: DType, c_smem_layout: Layout, num_output_stages: Int, //, epilogue_dtype: DType, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, is_lower_frag_required: Bool, num_stages: Int, simd_size: Int, stage: Int, rep_frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type]
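Writes an accumulator tile to shared memory (SMEM) and applies an element-wise epilogue lambda.
This writer handles the SMEM-based epilogue path, used when `register_based_epilogue=False`. The parameters `c_type`, `c_smem_layout`, and `num_output_stages` are infer-only: they are inferred from the type of the `c_tiles` argument. The comptime members `stageN` and `stage_contiguous_size` are derived from `c_smem_layout`.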
Fields
- warp_id (UInt32)
- c_tiles (SMemEpilogueWriter[epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].CTileArray)
- M (UInt32)
- N (UInt32)
- c_row (UInt32)
- c_col (UInt32)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
barrier_threads
comptime barrier_threads = (num_output_warps * WARP_SIZE)
CTileArray
comptime CTileArray = SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128]
data_paths
comptime data_paths = 16
N_dim
comptime N_dim = 0 if transpose_c else 1
stage_contiguous_size
comptime stage_contiguous_size = c_smem_layout.shape[1].value()
stageN
comptime stageN = c_smem_layout.shape[SMemEpilogueWriter[epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].N_dim].value()
swizzle
comptime swizzle = make_swizzle[c_type, c_swizzle]()
swizzle_width
comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())
Tile
comptime Tile = AccumTile[epilogue_dtype, rep_frag_size]
Methods
__init__
__init__(warp_id: UInt32, c_tiles: SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128], c_shape: Tuple[UInt32, UInt32], c_coord: Tuple[UInt32, UInt32]) -> Self
Initialize the SMEM epilogue writer.
write_tile
write_tile(self, tile: AccumTile[epilogue_dtype, rep_frag_size])
Write accumulator tile to SMEM and apply epilogue lambda.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!