Mojo struct
BlockScaledMmaOp
struct BlockScaledMmaOp[mma_shape: IndexList[3], num_m_mmas: Int, num_n_mmas: Int, num_k_tiles: Int]
Register ownership + GMEM loading + block-scaled MFMA execution.
Loads packed uint8 A/B fragments from GMEM and executes cdna4_block_scaled_mfma with per-lane E8M0 scale values.
Scale operand model: Each lane holds 32 FP4 elements and one E8M0 scale byte, matching the MX format's per-32-element granularity exactly. For 16x16x128, the 64 lanes cover 16 rows x 4 K-groups:
- lane_row = lane_id % 16 (matrix row)
- lane_k_group = lane_id / 16 (K-group 0..3)

Each lane loads its own scale: scale[row, base_k + k_group]. The scale byte is placed in byte 0 of the Int32 word.
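This lane-to-scale mapping can be modeled in a short Python sketch (function names are illustrative, not part of the API):

```python
# Illustrative model of the per-lane scale mapping for the 16x16x128 shape.
WARP_SIZE = 64
MMA_M = 16

def lane_scale_coords(lane_id: int) -> tuple[int, int]:
    """Map a lane to the (row, k_group) of the scale it loads."""
    lane_row = lane_id % MMA_M       # matrix row, 0..15
    lane_k_group = lane_id // MMA_M  # K-group along K, 0..3
    return lane_row, lane_k_group

def pack_scale(scale_byte: int) -> int:
    """Place the E8M0 scale byte in byte 0 of the Int32 operand word."""
    return scale_byte & 0xFF

# The 64 lanes cover all 16 rows x 4 K-groups exactly once.
coords = {lane_scale_coords(lane) for lane in range(WARP_SIZE)}
assert len(coords) == WARP_SIZE
```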
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
c_frag_size
comptime c_frag_size = ((BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_M * BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_N) // WARP_SIZE)
mma_frag_width
comptime mma_frag_width = 16
MMA_K
comptime MMA_K = mma_shape[2]
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
packed_k_per_mma
comptime packed_k_per_mma = (BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_K // 2)
scales_per_mma
comptime scales_per_mma = (BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_K // 32)
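For the 16x16x128 shape, these members work out as follows (a Python check of the arithmetic above; WARP_SIZE is 64 on CDNA4):

```python
# Derived comptime values for mma_shape = (16, 16, 128) on a 64-lane warp.
WARP_SIZE = 64
MMA_M, MMA_N, MMA_K = 16, 16, 128

c_frag_size = (MMA_M * MMA_N) // WARP_SIZE  # float32 accumulators per lane
packed_k_per_mma = MMA_K // 2               # bytes of packed FP4 per row (2 values per uint8)
scales_per_mma = MMA_K // 32                # E8M0 scales per row per MMA

assert (c_frag_size, packed_k_per_mma, scales_per_mma) == (4, 64, 4)
```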
Methods
__init__
__init__(out self)
accum_tile
accum_tile(self) -> ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Returns:
A reference to the warp-local float32 accumulator tile (`self._c_reg`) in register (LOCAL address space) memory.
load_frag_from_gmem
load_frag_from_gmem[k_tile_idx: Int, a_stride: Int, b_stride: Int](self, a_gmem_ptr: UnsafePointer[UInt8, a_gmem_ptr.origin], b_gmem_ptr: UnsafePointer[UInt8, b_gmem_ptr.origin], a_m_offset: Int, b_n_offset: Int, k_byte_offset: Int)
Load MXFP4 A/B fragments directly from GMEM (bypasses SMEM).
Uses the row-major MFMA lane mapping from the CDNA4 ISA:
- lane_row = lane_id % MMA_M
- lane_chunk = lane_id / MMA_M
- offset = lane_row * stride + lane_chunk * 16 + k_byte_offset
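The offset formula can be sketched in Python (the function name is illustrative):

```python
MMA_M = 16  # rows per MMA tile for the 16x16x128 shape

def frag_byte_offset(lane_id: int, stride: int, k_byte_offset: int) -> int:
    """Byte offset into GMEM of a lane's 16-byte packed-FP4 fragment."""
    lane_row = lane_id % MMA_M     # row within the MMA tile
    lane_chunk = lane_id // MMA_M  # which 16-byte chunk along K
    return lane_row * stride + lane_chunk * 16 + k_byte_offset

# Lane 17 with a 64-byte row stride: row 1, chunk 1 -> 64 + 16 = 80.
assert frag_byte_offset(17, 64, 0) == 80
```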
load_scales
load_scales(self, a_scale_ptr: UnsafePointer[Float8_e8m0fnu, a_scale_ptr.origin], b_scale_ptr: UnsafePointer[Float8_e8m0fnu, b_scale_ptr.origin], a_m_base: Int, b_n_base: Int, k_tile_idx: Int, scale_stride_a: Int, scale_stride_b: Int)
Load per-lane E8M0 scale bytes for the MFMA.
Each lane holds 32 FP4 elements and needs the matching scale. Lane mapping for 16x16x128:
- lane_row = lane % MMA_M (matrix row, 0..15)
- lane_k_group = lane / MMA_M (K-group, 0..3)
- Scale index: scale_ptr[row * stride + base_k + lane_k_group]

This gives full per-32-element MX compliance: 64 lanes x 1 scale each = 16 rows x 4 K-groups = 64 distinct scale values.
Args:
- a_scale_ptr (UnsafePointer): Base pointer to A scales [M, K//32].
- b_scale_ptr (UnsafePointer): Base pointer to B scales [N, K//32].
- a_m_base (Int): Global M-row offset for this warp's first MMA tile.
- b_n_base (Int): Global N-row offset for this warp's first MMA tile.
- k_tile_idx (Int): Which k-tile we're processing.
- scale_stride_a (Int): Row stride of the A scales tensor.
- scale_stride_b (Int): Row stride of the B scales tensor.
mma
mma[k_tile_idx: Int](self)
Execute block-scaled MFMA for k-tile k_tile_idx.
The MFMA call swaps B→src_a and A→src_b to match the AMD convention (gpu_mma(c, b, a, c)). The accumulator is stored in row-major order [num_m_mmas, num_n_mmas * c_frag_size] matching the output store's (m_mma, n_mma) indexing.
Loop order: outer=m (output M-tiles), inner=n (output N-tiles). For each (m, n) pair, we read the A fragment for M-tile m and B fragment for N-tile n, then pass B as src_a and A as src_b.
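The loop nest and accumulator layout described above can be sketched as (tile counts are illustrative; the MFMA call itself is elided):

```python
# Row-major accumulator indexing for the (m_mma, n_mma) loop nest.
num_m_mmas, num_n_mmas, c_frag_size = 2, 2, 4

offsets = []
for m_mma in range(num_m_mmas):       # outer: output M-tiles
    for n_mma in range(num_n_mmas):   # inner: output N-tiles
        # Fragment slot in the [num_m_mmas, num_n_mmas * c_frag_size] accumulator.
        offsets.append((m_mma * num_n_mmas + n_mma) * c_frag_size)
        # Here the MFMA would be issued with the B fragment for N-tile n_mma
        # as src_a and the A fragment for M-tile m_mma as src_b.

assert offsets == [0, 4, 8, 12]  # contiguous row-major fragments
```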