Mojo struct
RegTileEpilogue
struct RegTileEpilogue[c_type: DType, chunk_width: Int, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]
Per-lane MFMA epilogue writer with optional fused elementwise lambda.
Encapsulates the per-lane handoff from (m_global, n_global) to a store or
lambda call at the end of an AMD matmul kernel. Each store() call
writes one SIMD chunk of chunk_width columns at a single row,
the natural shape of an AMD MFMA output fragment for one lane.
Per-lane bound handling:
- In-bounds chunk (n + chunk_width <= n_total): one SIMD store or one lambda call.
- Partial chunk (n < n_total < n + chunk_width): per-element fallback. The SIMD-of-chunk_width store would otherwise spill into the next row of the buffer (where stride == N), so we degrade to up to chunk_width scalar stores or scalar lambda calls, each gated on col < n_total. This is what makes the writer correct for N-misaligned outputs.
- Fully OOB column (n >= n_total): skip silently.
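The three column cases above can be modelled on the host. The following is a minimal Python sketch, not the Mojo implementation: `store_chunk`, the flat list `buf`, and the returned path labels are hypothetical names introduced for illustration, and the chunk width is taken from the length of `values`.

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical stand-in for the fused lambda: receives (m, n) and a chunk.
Lambda = Callable[[Tuple[int, int], List[float]], None]


def store_chunk(
    buf: List[float],
    row_stride: int,
    n_total: int,
    values: List[float],
    m: int,
    n: int,
    lambda_fn: Optional[Lambda] = None,
) -> str:
    """Model one store() call; returns which path was taken."""
    chunk_width = len(values)

    # Fully out-of-bounds column: nothing is written.
    if n >= n_total:
        return "skipped"

    # In-bounds chunk: one vector store or one lambda call.
    if n + chunk_width <= n_total:
        if lambda_fn is not None:
            lambda_fn((m, n), list(values))
        else:
            base = m * row_stride + n
            buf[base:base + chunk_width] = values
        return "vector"

    # Partial chunk: per-element fallback, each element gated on col < n_total.
    for i, v in enumerate(values):
        col = n + i
        if col < n_total:
            if lambda_fn is not None:
                lambda_fn((m, col), [v])
            else:
                buf[m * row_stride + col] = v
    return "scalar"


# Two rows of N = 6 columns, chunk width 4, so the second chunk is partial.
buf = [0.0] * 12
print(store_chunk(buf, 6, 6, [1.0, 2.0, 3.0, 4.0], m=0, n=0))  # vector
print(store_chunk(buf, 6, 6, [5.0, 6.0, 7.0, 8.0], m=0, n=4))  # scalar
print(store_chunk(buf, 6, 6, [9.0, 9.0, 9.0, 9.0], m=0, n=8))  # skipped
print(buf[6:])  # row 1 is untouched: six zeros
```

Without the per-element gate, the second call would have written 7.0 and 8.0 into the first two columns of row 1, which is the spill the partial-chunk case is described as preventing.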
The caller is responsible for the M bound check before calling
store(): a split-K matmul kernel passes a workspace row that
differs from the logical output row, so the writer cannot derive
a single M bound that applies to both DRAM and lambda modes.
With elementwise_lambda_fn=None writes go to DRAM at
c_ptr + m * row_stride + n directly (no buffer-resource clamp;
the partial-chunk fallback gates by n_total explicitly). With a
lambda set the lambda receives global (m, n) and the SIMD chunk;
DRAM is left untouched. Lambda mode therefore requires the caller
to pass m as the LOGICAL output row, which is incompatible with a
per-split workspace write. Kernels that use both split-K and a
fused lambda should not set the lambda on the per-split matmul
kernel; instead run a non-fused split-K and apply the lambda in
the reduce kernel that consumes the partials.
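To illustrate why the lambda belongs in the reduce step, here is a small Python sketch under stated assumptions: the workspace is a flat list with shape (num_splits * M, N) and row stride N, and `workspace_offset`, `reduce_with_lambda`, and `add_one` are hypothetical helpers rather than part of the library.

```python
from typing import Callable, Dict, List, Tuple


def workspace_offset(split_id: int, m: int, n: int, M: int, row_stride: int) -> int:
    """Element offset used by a per-split write: row is split_id * M + m."""
    workspace_row = split_id * M + m
    return workspace_row * row_stride + n


def reduce_with_lambda(
    workspace: List[float],
    num_splits: int,
    M: int,
    N: int,
    lambda_fn: Callable[[Tuple[int, int], float], None],
) -> None:
    """Sum the partials and hand each result to the lambda at the logical (m, n)."""
    for m in range(M):
        for n in range(N):
            total = sum(
                workspace[workspace_offset(s, m, n, M, N)] for s in range(num_splits)
            )
            lambda_fn((m, n), total)


M, N, num_splits = 2, 2, 2
workspace = [0.0] * (num_splits * M * N)

# Non-fused per-split writes: plain stores at workspace rows, no lambda.
partials = {0: [[1.0, 2.0], [3.0, 4.0]], 1: [[10.0, 20.0], [30.0, 40.0]]}
for split_id, tile in partials.items():
    for m in range(M):
        for n in range(N):
            workspace[workspace_offset(split_id, m, n, M, N)] = tile[m][n]

# The fused epilogue runs once per logical element, after the reduction.
out: Dict[Tuple[int, int], float] = {}


def add_one(idx: Tuple[int, int], value: float) -> None:
    out[idx] = value + 1.0


reduce_with_lambda(workspace, num_splits, M, N, add_one)
print(out)  # {(0, 0): 12.0, (0, 1): 23.0, (1, 0): 34.0, (1, 1): 45.0}
```

Had the lambda been set on the per-split writes instead, it would have been called with workspace rows 0 through 3 and with unreduced partial values, neither of which matches the logical output.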
Parameters
- c_type (DType): Output element type.
- chunk_width (Int): Number of contiguous columns per lane per call. For 16x16x* MFMA on AMD this is MMA_M * MMA_N // WARP_SIZE = 4. For 32x32x* MFMA the natural per-lane fragment is 16 elements but they are spread across non-contiguous columns, so callers should fan out into per-element calls (chunk_width = 1) instead.
- elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional fused epilogue.
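The chunk_width arithmetic can be checked with a short Python sketch. It assumes WARP_SIZE = 64, which is the value implied by 16 * 16 // WARP_SIZE = 4; the helper names are hypothetical.

```python
WARP_SIZE = 64  # assumption: implied by MMA_M * MMA_N // WARP_SIZE = 4 for 16x16


def elements_per_lane(mma_m: int, mma_n: int, warp_size: int = WARP_SIZE) -> int:
    """Output elements each lane holds for one MMA_M x MMA_N fragment."""
    return mma_m * mma_n // warp_size


def store_calls_per_lane(mma_m: int, mma_n: int, chunk_width: int) -> int:
    """Number of store() calls one lane issues per fragment."""
    return elements_per_lane(mma_m, mma_n) // chunk_width


# 16x16: four contiguous columns per lane, written in a single call.
print(elements_per_lane(16, 16), store_calls_per_lane(16, 16, chunk_width=4))

# 32x32: sixteen elements per lane, fanned out one element per call.
print(elements_per_lane(32, 32), store_calls_per_lane(32, 32, chunk_width=1))
```

The sketch only counts elements and calls; it does not model which columns a lane owns in the 32x32 case.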
Fields
- c_ptr_as_int (Int): Integer address of the destination's base pointer. Stored as Int rather than UnsafePointer because the dst tile's origin may be any mutable origin and the writer is reused across kernels with different origin types.
- row_stride (Int): Element stride between consecutive rows of dst.
- n_total (Int): N dimension of the output, used for the chunk-boundary detection and the per-element OOB gate.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
Methods
__init__
__init__(dst: TileTensor[c_type, address_space=dst.address_space, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size]) -> Self
Build from the (mutable) destination DRAM tile.
For non-split-K: dst is the logical output tensor; m in
subsequent store() calls is the logical output row, which
is also the DRAM row.
For split-K matmul kernels: dst is the
(num_splits * M, N) workspace; m in store() is the
workspace row (split_id * M + pid_m * BM + ...). Callers
must keep elementwise_lambda_fn unset in that case; see
the struct doc.
Args:
- dst (TileTensor[c_type, address_space=dst.address_space, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size]): Destination DRAM tile (must be mutable).
store
store(self, v: SIMD[c_type, chunk_width], *, m: Int, n: Int)
Write a SIMD chunk at (m, n) of dst.
The caller has already checked the M bound. If the chunk
straddles n_total (a partial block at the column boundary)
the writer falls back to per-element stores or lambda calls.
Args:
- v (SIMD[c_type, chunk_width]): SIMD value to write (already cast to Self.c_type).
- m (Int): Destination row (DRAM row for split-K workspace, or logical output row for non-split-K / reduce-kernel lambda mode).
- n (Int): Destination starting column. Caller has typically offset by lane_group * chunk_width already.
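The caller-side contract (M bound checked before the call, n already offset by lane_group * chunk_width) might look like the following Python sketch. Everything here is hypothetical: `caller_store`, `n_tile_base`, and the `calls` list that stands in for the real store() call are illustrative names, and how a kernel derives lane_group from the lane index is not modelled.

```python
from typing import List, Tuple

Call = Tuple[int, int, List[float]]


def caller_store(
    calls: List[Call],
    values: List[float],
    m: int,
    m_total: int,
    n_tile_base: int,
    lane_group: int,
) -> bool:
    """Apply the caller's M guard, then record the store() arguments."""
    chunk_width = len(values)

    # The writer does not check M, so the caller skips out-of-range rows.
    if m >= m_total:
        return False

    # Starting column for this lane's chunk within the output.
    n = n_tile_base + lane_group * chunk_width

    # Stands in for: writer.store(v, m=m, n=n)
    calls.append((m, n, list(values)))
    return True


calls: List[Call] = []
print(caller_store(calls, [1.0, 2.0, 3.0, 4.0], m=2, m_total=3, n_tile_base=16, lane_group=1))
print(caller_store(calls, [5.0, 6.0, 7.0, 8.0], m=3, m_total=3, n_tile_base=16, lane_group=1))
print(calls)  # only the in-range row reaches the writer, at column 20
```

Column bounds are deliberately left out of this guard, since the writer handles in-bounds, partial, and fully out-of-bounds chunks itself.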