
Mojo struct

BlockScaledMmaOp

struct BlockScaledMmaOp[mma_shape: IndexList[3], num_m_mmas: Int, num_n_mmas: Int, num_k_tiles: Int]

Register ownership + GMEM loading + block-scaled MFMA execution.

Loads packed uint8 A/B fragments from GMEM and executes cdna4_block_scaled_mfma with per-lane E8M0 scale values.

Scale operand model: Each lane holds 32 FP4 elements and one E8M0 scale byte, matching the MX format's per-32-element granularity exactly. For 16x16x128, the 64 lanes cover 16 rows x 4 K-groups:

  • lane_row = lane_id % 16 (matrix row)
  • lane_k_group = lane_id / 16 (K-group 0..3)

Each lane loads its own scale: scale[row, base_k + k_group]. The scale byte is placed in byte 0 of the Int32 word.
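The per-lane mapping above can be sketched in Python. This is an illustrative model, not the Mojo implementation: `scale_word_for_lane`, `scales`, and `stride` are hypothetical names, and the dummy scale values stand in for real E8M0 bytes.

```python
MMA_M = 16  # rows per MFMA tile (16x16x128 case)

def scale_word_for_lane(lane_id, scales, stride, base_k):
    """Build the Int32 scale operand for one lane: E8M0 byte in byte 0."""
    lane_row = lane_id % MMA_M        # matrix row, 0..15
    lane_k_group = lane_id // MMA_M   # K-group, 0..3
    scale_byte = scales[lane_row * stride + base_k + lane_k_group]
    return scale_byte & 0xFF          # byte 0 of the word; upper bytes stay zero

stride = 4                            # K // 32 scale columns for K = 128
scales = list(range(16 * stride))     # 64 distinct dummy scale bytes
words = [scale_word_for_lane(lane, scales, stride, 0) for lane in range(64)]
```

Because each of the 64 lanes maps to a distinct (row, K-group) pair, the 64 loaded words are all distinct, matching the per-32-element MX granularity.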

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

c_frag_size

comptime c_frag_size = ((BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_M * BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_N) // WARP_SIZE)

mma_frag_width

comptime mma_frag_width = 16

MMA_K

comptime MMA_K = mma_shape[2]

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]

packed_k_per_mma

comptime packed_k_per_mma = (BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_K // 2)

scales_per_mma

comptime scales_per_mma = (BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_K // 32)
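The comptime members above reduce to small constants for a concrete shape. A sketch of the arithmetic, assuming the CDNA wave64 wavefront (WARP_SIZE = 64) and the 16x16x128 shape used throughout the docstrings:

```python
WARP_SIZE = 64                 # assumption: CDNA wave64 wavefront
mma_shape = [16, 16, 128]      # the 16x16x128 shape from the docstrings
MMA_M, MMA_N, MMA_K = mma_shape

c_frag_size = (MMA_M * MMA_N) // WARP_SIZE   # fp32 accumulator values per lane
packed_k_per_mma = MMA_K // 2                # FP4 packs 2 elements per byte
scales_per_mma = MMA_K // 32                 # one E8M0 scale per 32 elements
```

For this shape, each lane holds 4 accumulator values, a packed A/B row spans 64 bytes along K, and each MMA consumes 4 scale K-groups.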

Methods

__init__

__init__(out self)

accum_tile

accum_tile(self) -> ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Returns:

ref: A reference to the warp's float32 accumulator tile in local (register) memory.

load_frag_from_gmem

load_frag_from_gmem[k_tile_idx: Int, a_stride: Int, b_stride: Int](self, a_gmem_ptr: UnsafePointer[UInt8, a_gmem_ptr.origin], b_gmem_ptr: UnsafePointer[UInt8, b_gmem_ptr.origin], a_m_offset: Int, b_n_offset: Int, k_byte_offset: Int)

Load MXFP4 A/B fragments directly from GMEM (bypasses SMEM).

Uses the row-major MFMA lane mapping from the CDNA4 ISA:

  • lane_row = lane_id % MMA_M
  • lane_chunk = lane_id / MMA_M
  • offset = lane_row * stride + lane_chunk * 16 + k_byte_offset
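The offset formula can be checked with a small Python sketch (`frag_byte_offset` is an illustrative name; `stride` is the row stride in bytes of the packed GMEM tensor):

```python
MMA_M = 16  # 16x16x128 case

def frag_byte_offset(lane_id, stride, k_byte_offset):
    """Per-lane GMEM byte offset for a packed FP4 fragment."""
    lane_row = lane_id % MMA_M      # row within the MMA tile
    lane_chunk = lane_id // MMA_M   # which 16-byte chunk along K
    return lane_row * stride + lane_chunk * 16 + k_byte_offset
```

Consecutive lanes within a group of MMA_M walk down rows; each group of MMA_M lanes advances 16 bytes along K, so a wave of 64 lanes covers 16 rows x 64 bytes of packed data per call.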

load_scales

load_scales(self, a_scale_ptr: UnsafePointer[Float8_e8m0fnu, a_scale_ptr.origin], b_scale_ptr: UnsafePointer[Float8_e8m0fnu, b_scale_ptr.origin], a_m_base: Int, b_n_base: Int, k_tile_idx: Int, scale_stride_a: Int, scale_stride_b: Int)

Load per-lane E8M0 scale bytes for the MFMA.

Each lane holds 32 FP4 elements and needs the matching scale. Lane mapping for 16x16x128:

  • lane_row = lane % MMA_M (matrix row, 0..15)
  • lane_k_group = lane / MMA_M (K-group, 0..3)
  • Scale index: scale_ptr[row * stride + base_k + lane_k_group]

This gives full per-32-element MX compliance: 64 lanes x 1 scale each = 16 rows x 4 K-groups = 64 distinct scale values.

Args:

  • a_scale_ptr (UnsafePointer): Base pointer to A scales [M, K//32].
  • b_scale_ptr (UnsafePointer): Base pointer to B scales [N, K//32].
  • a_m_base (Int): Global M-row offset for this warp's first MMA tile.
  • b_n_base (Int): Global N-row offset for this warp's first MMA tile.
  • k_tile_idx (Int): Index of the k-tile being processed.
  • scale_stride_a (Int): Row stride of A scales tensor.
  • scale_stride_b (Int): Row stride of B scales tensor.
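Combining the Args above, a sketch of the full A-scale index each lane computes. `a_scale_index` is an illustrative name, and base_k = k_tile_idx * scales_per_mma is an assumption about how the k-tile index selects the K-group block:

```python
MMA_M = 16
scales_per_mma = 4   # MMA_K // 32 for K = 128

def a_scale_index(lane_id, a_m_base, k_tile_idx, scale_stride_a):
    """Flat index into the A scales tensor [M, K//32] for one lane."""
    lane_row = lane_id % MMA_M
    lane_k_group = lane_id // MMA_M
    base_k = k_tile_idx * scales_per_mma   # assumed k-tile -> K-group mapping
    return (a_m_base + lane_row) * scale_stride_a + base_k + lane_k_group

# 64 lanes -> 64 distinct scale slots for one MMA tile
idxs = {a_scale_index(lane, 0, 0, scales_per_mma) for lane in range(64)}
```

B scales follow the same pattern with b_n_base and scale_stride_b.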

mma

mma[k_tile_idx: Int](self)

Execute block-scaled MFMA for k-tile k_tile_idx.

The MFMA call swaps B→src_a and A→src_b to match the AMD convention (gpu_mma(c, b, a, c)). The accumulator is stored in row-major order [num_m_mmas, num_n_mmas * c_frag_size] matching the output store's (m_mma, n_mma) indexing.

Loop order: outer=m (output M-tiles), inner=n (output N-tiles). For each (m, n) pair, we read the A fragment for M-tile m and B fragment for N-tile n, then pass B as src_a and A as src_b.
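The loop order and operand swap can be sketched as follows; `gpu_mma` here is a Python stand-in that records its operands, not the real MFMA intrinsic:

```python
num_m_mmas, num_n_mmas = 2, 2
calls = []

def gpu_mma(dst, src_a, src_b, acc):
    """Stand-in for the MFMA call: record which fragments land in which slot."""
    calls.append((src_a, src_b))

a_frags = [f"a{m}" for m in range(num_m_mmas)]  # one A fragment per M-tile
b_frags = [f"b{n}" for n in range(num_n_mmas)]  # one B fragment per N-tile

for m in range(num_m_mmas):          # outer: output M-tiles
    for n in range(num_n_mmas):      # inner: output N-tiles
        # AMD convention: B goes in src_a, A goes in src_b
        gpu_mma(None, b_frags[n], a_frags[m], None)
```

The recorded call order shows the row-major (m, n) traversal with B always in the src_a slot, matching the accumulator's [num_m_mmas, num_n_mmas * c_frag_size] layout.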
