Mojo struct

BlockScaledMmaOp

struct BlockScaledMmaOp[mma_shape: IndexList[3], num_m_mmas: Int, num_n_mmas: Int, num_k_tiles: Int]

Register ownership + block-scaled MFMA execution.

Loads packed uint8 A/B fragments from SMEM or GMEM and executes cdna4_block_scaled_mfma with per-lane E8M0 scale values.

Scale operand model: Each lane holds 32 FP4 elements and one E8M0 scale byte, matching the MX format's per-32-element granularity exactly. For 16x16x128, the 64 lanes cover 16 rows x 4 K-groups:

lane_row = lane_id % 16 (matrix row)
lane_k_group = lane_id / 16 (K-group 0..3)
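A minimal sketch of this lane mapping, assuming the lane_id() intrinsic from the gpu package (the import path is an assumption and may differ by build):

```mojo
from gpu import lane_id  # assumed import path

fn lane_scale_coords() -> Tuple[Int, Int]:
    # For the 16x16x128 shape, the 64 lanes of a wavefront cover
    # 16 matrix rows x 4 K-groups of 32 FP4 elements each.
    var lane = Int(lane_id())
    return (lane % 16, lane // 16)  # (lane_row, lane_k_group)
```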

Scale packing: 4 spatial MMA tiles' scale bytes are packed into one Int32 VGPR: byte i holds the scale for m_mma=i (A) or n_mma=i (B). The MFMA byte-index selector (OP_SEL) picks the correct byte for each MMA tile, so one scale load covers all 4 m_mma or n_mma positions with zero overhead.
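The packed layout is plain byte placement; here is a hypothetical helper showing the equivalent software construction (the real packing happens inside load_scales_from_smem):

```mojo
fn pack_scale_bytes(s0: UInt8, s1: UInt8, s2: UInt8, s3: UInt8) -> Int32:
    # Byte i of the packed VGPR holds the E8M0 scale for MMA tile i;
    # the MFMA OP_SEL byte index later selects that byte in hardware.
    return (
        s0.cast[DType.int32]()
        | (s1.cast[DType.int32]() << 8)
        | (s2.cast[DType.int32]() << 16)
        | (s3.cast[DType.int32]() << 24)
    )
```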

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

c_frag_size

comptime c_frag_size = ((BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_M * BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_N) // WARP_SIZE)

mma_frag_width_bytes

comptime mma_frag_width_bytes = 16

MMA_K

comptime MMA_K = mma_shape[2]

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]

packed_k_per_mma

comptime packed_k_per_mma = (BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_K // 2)

scales_per_mma

comptime scales_per_mma = (BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_K // 32)
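As a worked example (not from the source), with mma_shape = (16, 16, 128) and WARP_SIZE = 64 (the CDNA wavefront size), these members evaluate to:

```mojo
# Worked values for mma_shape = (16, 16, 128), WARP_SIZE = 64:
comptime c_frag_size_ex = (16 * 16) // 64  # 4 float32 accumulators per lane
comptime packed_k_ex = 128 // 2            # 64 packed uint8 per row (2 FP4/byte)
comptime scales_ex = 128 // 32             # 4 E8M0 scale bytes per row
```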

Methods

__init__

__init__(out self)

accum_tile

accum_tile(self) -> ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Returns:

ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]: A reference to this warp's per-lane float32 accumulator register tile.

load_frag_from_smem

load_frag_from_smem[k_tile_idx: Int](self, a_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=a_smem_warp.linear_idx_type, element_size=a_smem_warp.element_size], b_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=b_smem_warp.linear_idx_type, element_size=b_smem_warp.element_size])

Load MXFP4 A/B fragments from row-major SMEM for k-tile k_tile_idx.

Uses tile to extract the [MMA_M, packed_k_per_mma] sub-tile, vectorize to group 64 bytes into 4 x 16-byte elements, and distribute with col_major[MMA_M, 4] to assign each lane its 16-byte fragment, matching the MFMA-native lane mapping.
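To make those shapes concrete for the 16x16x128 case (a worked check, assuming WARP_SIZE = 64):

```mojo
comptime packed_k = 128 // 2                 # 64 packed uint8 per row per k-tile
comptime frags_per_row = packed_k // 16      # 4 fragments of mma_frag_width_bytes
comptime lanes_covered = 16 * frags_per_row  # 64 = WARP_SIZE: one fragment per lane
```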

load_scales_from_smem

load_scales_from_smem[k_tile_idx: Int](mut self, a_scale_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=a_scale_smem_warp.linear_idx_type, element_size=a_scale_smem_warp.element_size], b_scale_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=b_scale_smem_warp.linear_idx_type, element_size=b_scale_smem_warp.element_size])

Load packed scale VGPRs for k-tile k_tile_idx from SMEM.

Packs num_m_mmas (A) or num_n_mmas (B) scale bytes into one Int32 each using the same col_major[MMA_M, WARP_SIZE/MMA_M] distribute pattern as load_frag_from_smem. Each lane picks one scale byte via (lane_row, lane_k_group). TileTensor's stride handling means this works for any parent SMEM layout.

The MFMA byte-index selector (a_scale_byte_index=m_mma, b_scale_byte_index=n_mma) picks the correct byte, so no shifts or masks are needed at consumption time.
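For clarity, a hypothetical software equivalent of what the OP_SEL byte selector does in hardware (the kernel itself never executes these shifts):

```mojo
fn select_scale_byte(packed: Int32, idx: Int) -> UInt8:
    # Byte idx of the packed VGPR is the E8M0 scale for MMA tile idx.
    return ((packed >> Int32(8 * idx)) & Int32(0xFF)).cast[DType.uint8]()
```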

mma

mma[k_tile_idx: Int](self)

Execute block-scaled MFMA for k-tile k_tile_idx.

B→src_a, A→src_b (AMD MFMA convention). The packed scale VGPRs hold one byte per spatial MMA tile. a_scale_byte_index=m selects byte m from _a_scale_packed, b_scale_byte_index=n selects byte n from _b_scale_packed.
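
Putting the pieces together, a minimal usage sketch; the *_smem_warp tiles are assumed to be staged by the surrounding kernel, and the names are illustrative:

```mojo
var op = BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles]()

@parameter
for k_tile in range(num_k_tiles):
    op.load_frag_from_smem[k_tile](a_smem_warp, b_smem_warp)
    op.load_scales_from_smem[k_tile](a_scale_smem_warp, b_scale_smem_warp)
    op.mma[k_tile]()

# Accumulated float32 results live in the register tile:
ref c_reg = op.accum_tile()
```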