Mojo struct
BlockScaledMmaOp
struct BlockScaledMmaOp[mma_shape: IndexList[3], num_m_mmas: Int, num_n_mmas: Int, num_k_tiles: Int]
Register ownership + block-scaled MFMA execution.
Loads packed uint8 A/B fragments from SMEM or GMEM and executes cdna4_block_scaled_mfma with per-lane E8M0 scale values.
Scale operand model: Each lane holds 32 FP4 elements and one E8M0 scale byte, matching the MX format's per-32-element granularity exactly. For 16x16x128, 64 lanes cover 16 rows x 4 K-groups:
lane_row = lane_id % 16 (matrix row)
lane_k_group = lane_id / 16 (K-group 0..3)
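A minimal sketch of this lane mapping for the 16x16x128 shape (standalone and illustrative only, not part of the struct's API):

```mojo
# Sketch of the lane -> (row, K-group) mapping described above for 16x16x128:
# 64 lanes cover 16 matrix rows x 4 K-groups of 32 FP4 elements each.
fn main():
    for lane_id in range(64):
        var lane_row = lane_id % 16       # matrix row the scale byte applies to
        var lane_k_group = lane_id // 16  # 32-element K-group, 0..3
        print("lane", lane_id, "-> row", lane_row, ", k_group", lane_k_group)
```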
Scale packing: 4 spatial MMA tiles' scale bytes are packed into one Int32 VGPR; byte i holds the scale for m_mma=i (A) or n_mma=i (B). The MFMA byte-index selector (OP_SEL) picks the correct byte for each MMA tile, so one scale load covers all 4 m_mma or n_mma positions with zero overhead.
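A schematic software model of this packing and byte selection (the actual selection happens inside the MFMA instruction via OP_SEL; these helper names are illustrative, not library API):

```mojo
# Software model of the packing: byte i of the Int32 VGPR holds the E8M0
# scale for m_mma = i (A) or n_mma = i (B).
fn pack_scales(scale_bytes: SIMD[DType.uint8, 4]) -> Int32:
    var packed = Int32(0)
    for i in range(4):
        packed |= scale_bytes[i].cast[DType.int32]() << Int32(8 * i)
    return packed

# What the MFMA byte-index selector (OP_SEL) does in hardware, modeled in
# software: pick byte mma_idx without unpacking the VGPR at runtime.
fn select_scale_byte(packed: Int32, mma_idx: Int) -> UInt8:
    return ((packed >> Int32(8 * mma_idx)) & Int32(0xFF)).cast[DType.uint8]()
```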
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
c_frag_size
comptime c_frag_size = ((BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_M * BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_N) // WARP_SIZE)
mma_frag_width_bytes
comptime mma_frag_width_bytes = 16
MMA_K
comptime MMA_K = mma_shape[2]
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
packed_k_per_mma
comptime packed_k_per_mma = (BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_K // 2)
scales_per_mma
comptime scales_per_mma = (BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles].MMA_K // 32)
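For example, with mma_shape = (16, 16, 128) and a 64-lane wavefront, these work out to c_frag_size = (16 * 16) / 64 = 4 accumulator values per lane, packed_k_per_mma = 128 / 2 = 64 bytes (two FP4 elements per byte), and scales_per_mma = 128 / 32 = 4 E8M0 scale bytes.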
Methods
__init__
__init__(out self)
accum_tile
accum_tile(self) -> ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Returns:
ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
load_frag_from_smem
load_frag_from_smem[k_tile_idx: Int](self, a_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=a_smem_warp.linear_idx_type, element_size=a_smem_warp.element_size], b_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=b_smem_warp.linear_idx_type, element_size=b_smem_warp.element_size])
Load MXFP4 A/B fragments from row-major SMEM for k-tile k_tile_idx.
Uses tile to extract the [MMA_M, packed_k_per_mma] sub-tile, vectorize to group 64 bytes into 4 x 16-byte elements, and distribute with col_major[MMA_M, 4] to assign each lane its 16-byte fragment, matching the MFMA native lane mapping.
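For the 16x16x128 case this amounts to the following per-lane byte offsets into the row-major [16, 64] packed sub-tile (illustrative index arithmetic only, not the TileTensor implementation):

```mojo
# Byte offset of each lane's 16-byte MXFP4 fragment within one row-major
# [MMA_M, packed_k_per_mma] = [16, 64] SMEM sub-tile, under the
# col_major[MMA_M, 4] distribution described above. Illustrative only.
fn frag_byte_offset(lane_id: Int) -> Int:
    var row = lane_id % 16       # col_major: the row index varies fastest
    var k_group = lane_id // 16  # which of the 4 x 16-byte vectors in the row
    return row * 64 + k_group * 16
```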
load_scales_from_smem
load_scales_from_smem[k_tile_idx: Int](mut self, a_scale_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=a_scale_smem_warp.linear_idx_type, element_size=a_scale_smem_warp.element_size], b_scale_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=b_scale_smem_warp.linear_idx_type, element_size=b_scale_smem_warp.element_size])
Load packed scale VGPRs for k-tile k_tile_idx from SMEM.
Packs num_m_mmas (A) or num_n_mmas (B) scale bytes into one Int32 each using the same col_major[MMA_M, WARP_SIZE/MMA_M] distribute pattern as load_frag_from_smem. Each lane picks one scale byte via (lane_row, lane_k_group). TileTensor's stride handling means this works for any parent SMEM layout.
The MFMA byte-index selector (a_scale_byte_index=m_mma, b_scale_byte_index=n_mma) picks the correct byte; no shifts or masks at consumption time.
mma
mma[k_tile_idx: Int](self)
Execute block-scaled MFMA for k-tile k_tile_idx.
B maps to src_a and A maps to src_b (AMD MFMA convention). The packed scale VGPRs hold one byte per spatial MMA tile. a_scale_byte_index=m selects byte m from _a_scale_packed, b_scale_byte_index=n selects byte n from _b_scale_packed.