Mojo struct
BlockScaledMmaOp
struct BlockScaledMmaOp[mma_shape: IndexList[3], num_m_mmas: Int, num_n_mmas: Int, num_k_tiles: Int, num_b_slots: Int = 1]
Register ownership + block-scaled MFMA execution.
Loads packed uint8 A/B fragments from SMEM or GMEM and executes cdna4_block_scaled_mfma with per-lane E8M0 scale values.
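For reference, the value that one block-scaled MFMA accumulates into a single output element can be modeled in plain Python. This sketch rests on a few assumptions. It follows the OCP Microscaling (MX) definitions, where FP4 values are E2M1 (one sign bit plus a 3-bit magnitude code) and an E8M0 scale byte s encodes 2^(s - 127), with 0xFF reserved for NaN. The helper names are hypothetical. The model ignores register layout, rounding, and hardware accumulation order.

```python
import math

# OCP MX FP4 (E2M1) magnitudes for 3-bit codes 0..7; bit 3 is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def fp4_to_float(code):
    """Decode a 4-bit E2M1 code (0..15) to a Python float."""
    mag = E2M1_MAGNITUDES[code & 0x7]
    return -mag if code & 0x8 else mag


def e8m0_to_float(s):
    """Decode an E8M0 scale byte: 2**(s - 127), with 0xFF as NaN."""
    return math.nan if s == 0xFF else 2.0 ** (s - 127)


def block_scaled_dot(a_codes, a_scales, b_codes, b_scales, block=32):
    """Reference value of one output element C[m, n] over K.

    a_codes / b_codes hold the K FP4 codes of row m of A and column n of B;
    a_scales / b_scales hold one E8M0 byte per 32-element K block.
    """
    n_blocks = len(a_scales)
    assert len(b_scales) == n_blocks
    assert len(a_codes) == len(b_codes) == block * n_blocks
    acc = 0.0
    for j in range(n_blocks):
        # Both operands' block scales apply to every product in block j.
        scale = e8m0_to_float(a_scales[j]) * e8m0_to_float(b_scales[j])
        for k in range(j * block, (j + 1) * block):
            acc += fp4_to_float(a_codes[k]) * fp4_to_float(b_codes[k]) * scale
    return acc
```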
Scale operand model: Each lane holds 32 FP4 elements and one E8M0 scale byte, matching the MX format's per-32-element granularity exactly. For 16x16x128, the 64 lanes cover 16 rows x 4 K-groups:
lane_row = lane_id % 16 (matrix row)
lane_k_group = lane_id // 16 (K-group 0..3)
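The 16x16x128 lane mapping can be checked with a small pure-Python model (not the Mojo API). The model assumes WARP_SIZE = 64. It also assumes that K-group g owns the contiguous K range [32g, 32g + 32). Under those assumptions, every element of the 16x128 operand is held by exactly one lane, and each lane's 32 elements form one MX scale block.

```python
# Pure-Python model of the 16x16x128 lane mapping (not the Mojo API).
WARP_SIZE = 64                       # assumed CDNA wavefront size
MMA_M, MMA_K = 16, 128
K_GROUPS = WARP_SIZE // MMA_M        # 4 K-groups
ELEMS_PER_LANE = MMA_K // K_GROUPS   # 32 FP4 elements: one MX scale block


def lane_coords(lane_id):
    """Return (lane_row, lane_k_group, K range) owned by lane_id."""
    lane_row = lane_id % MMA_M          # matrix row
    lane_k_group = lane_id // MMA_M     # K-group 0..3
    # Assumption: K-group g covers the contiguous K range [32g, 32g + 32).
    k_start = lane_k_group * ELEMS_PER_LANE
    return lane_row, lane_k_group, range(k_start, k_start + ELEMS_PER_LANE)


# Map every (row, k) element of the 16x128 operand to the lane that holds it.
owner = {}
for lane in range(WARP_SIZE):
    row, _, k_range = lane_coords(lane)
    for k in k_range:
        assert (row, k) not in owner    # no element is held by two lanes
        owner[(row, k)] = lane
```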
Scale packing: The scale bytes of 4 spatial MMA tiles are packed into one Int32 VGPR; byte i holds the scale for m_mma=i (A) or n_mma=i (B). The MFMA byte-index selector (OP_SEL) picks the correct byte for each MMA tile, so one scale load covers all 4 m_mma or n_mma positions without extra shift or mask instructions.
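The following minimal sketch models this packing in Python with plain integers. pack_scales and select_byte are hypothetical helpers. The shift-and-mask in select_byte stands in for what the MFMA byte-index selector does in hardware at no extra instruction cost.

```python
def pack_scales(scale_bytes):
    """Pack up to 4 E8M0 scale bytes into one 32-bit word: byte i <- scale_bytes[i]."""
    assert len(scale_bytes) <= 4
    word = 0
    for i, s in enumerate(scale_bytes):
        assert 0 <= s <= 0xFF
        word |= s << (8 * i)
    return word


def select_byte(word, byte_index):
    """Model of the byte-index selector: extract byte `byte_index` of `word`."""
    return (word >> (8 * byte_index)) & 0xFF


# Scales for m_mma = 0..3 of one lane share a single VGPR-sized word.
packed = pack_scales([0x7F, 0x80, 0x7E, 0x82])
```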
Implemented traits
comptime members
c_frag_size
comptime c_frag_size = ((MMA_M * MMA_N) // WARP_SIZE)
mma_frag_width_bytes
comptime mma_frag_width_bytes = 16
MMA_K
comptime MMA_K = mma_shape[2]
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
packed_k_per_mma
comptime packed_k_per_mma = (MMA_K // 2)
scales_per_mma
comptime scales_per_mma = (MMA_K // 32)
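The comptime formulas can be evaluated by hand for a concrete shape. The sketch below mirrors them in Python, assuming WARP_SIZE = 64 (the CDNA wavefront size). The extra frag_bytes_per_lane entry is a derived quantity, not a struct member. It shows that each lane's operand fragment matches mma_frag_width_bytes for 16x16x128.

```python
WARP_SIZE = 64  # assumed CDNA wavefront size


def comptime_members(mma_shape):
    """Evaluate the comptime formulas above for a given (M, N, K) shape."""
    mma_m, mma_n, mma_k = mma_shape
    packed_k_per_mma = mma_k // 2                       # two FP4 values per uint8 byte
    return {
        "MMA_M": mma_m,
        "MMA_N": mma_n,
        "MMA_K": mma_k,
        "c_frag_size": (mma_m * mma_n) // WARP_SIZE,    # FP32 accumulators per lane
        "packed_k_per_mma": packed_k_per_mma,
        "scales_per_mma": mma_k // 32,                  # one E8M0 scale per 32 elements
        # Derived (not a struct member): bytes of one operand fragment per lane.
        "frag_bytes_per_lane": (mma_m * packed_k_per_mma) // WARP_SIZE,
    }


members = comptime_members((16, 16, 128))
```

For mma_shape = (16, 16, 128), this gives c_frag_size = 4, packed_k_per_mma = 64, and scales_per_mma = 4.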
Methods
__init__
def __init__(out self)
accum_tile
def accum_tile(self) -> ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
Returns:
ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
load_frag_from_smem
def load_frag_from_smem[k_tile_idx: Int](self, a_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=a_smem_warp.linear_idx_type, element_size=a_smem_warp.element_size], b_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=b_smem_warp.linear_idx_type, element_size=b_smem_warp.element_size])
Load MXFP4 A/B fragments from row-major SMEM for k-tile k_tile_idx.
Uses tile to extract the [MMA_M, packed_k_per_mma] sub-tile, vectorize to group each 64-byte row into 4 x 16-byte elements, and distribute with col_major[MMA_M, 4] to assign each lane the 16-byte fragment that matches the MFMA native lane mapping.
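The col_major[MMA_M, 4] distribution can be modeled as a lane-to-byte-offset function. This sketch makes the following assumptions:

- The warp tile is row-major with a hypothetical row pitch row_stride (in bytes).
- WARP_SIZE = 64.
- k-tile k_tile_idx starts at byte column k_tile_idx * packed_k_per_mma.

Under these assumptions, the 64 lanes' 16-byte fragments tile the [16, 64]-byte sub-tile exactly.

```python
WARP_SIZE = 64                         # assumed CDNA wavefront size
MMA_M, PACKED_K = 16, 64               # 16x16x128: 128 FP4 values = 64 bytes per row
VEC_BYTES = 16                         # mma_frag_width_bytes
VECS_PER_ROW = PACKED_K // VEC_BYTES   # 4 vectorized elements per row


def lane_fragment_offset(lane_id, k_tile_idx, row_stride):
    """Byte offset of lane_id's 16-byte fragment in a row-major SMEM tile.

    col_major[MMA_M, VECS_PER_ROW] distribution: consecutive lanes walk down
    the rows first, so lane -> (row = lane % MMA_M, vec = lane // MMA_M).
    row_stride is a hypothetical SMEM row pitch in bytes.
    """
    row = lane_id % MMA_M
    vec = lane_id // MMA_M
    assert vec < VECS_PER_ROW
    col = k_tile_idx * PACKED_K + vec * VEC_BYTES
    return row * row_stride + col
```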
load_a_frag_from_smem
def load_a_frag_from_smem[k_tile_idx: Int](self, a_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=a_smem_warp.linear_idx_type, element_size=a_smem_warp.element_size])
A-only variant of load_frag_from_smem for callers that source B elsewhere (e.g. preshuffled DRAM via PreshuffledBLoader).
load_b_frag_preshuffled
def load_b_frag_preshuffled[k_tile_idx: Int, N: Int, K_BYTES: Int, slot: Int = 0](self, b_loader: PreshuffledBLoader[N, K_BYTES], warp_n_off: Int, k_byte_base: Int)
Load B fragments directly from preshuffled DRAM into the b_reg slot selected by slot.
Each lane issues one buffer_load_dwordx4 per (k_tile, n_mma), following the per-lane MFMA mapping (lane % 16 → n-row, lane // 16 → k-group). The slot parameter selects which b_reg half to write into when num_b_slots > 1 (depth-2 prefetch).
load_scales_from_smem
def load_scales_from_smem[k_tile_idx: Int](mut self, a_scale_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=a_scale_smem_warp.linear_idx_type, element_size=a_scale_smem_warp.element_size], b_scale_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=b_scale_smem_warp.linear_idx_type, element_size=b_scale_smem_warp.element_size])
Load packed scale VGPRs for k-tile k_tile_idx from SMEM.
Packs the num_m_mmas A-scale bytes and the num_n_mmas B-scale bytes into one Int32 each, using the same col_major[MMA_M, WARP_SIZE/MMA_M] distribute pattern as load_frag_from_smem. Each lane picks one scale byte per MMA tile via (lane_row, lane_k_group). Because TileTensor handles strides, this works for any parent SMEM layout.
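The per-lane pick-and-pack step can be sketched in Python as follows. The scale layout is an assumption for illustration only: a row-major [num_mmas * MMA_M, K // 32] grid of E8M0 bytes, with MMA tile i occupying rows [i * MMA_M, (i + 1) * MMA_M). lane_packed_scale is a hypothetical helper, not part of the struct.

```python
WARP_SIZE, MMA_M, SCALES_PER_MMA = 64, 16, 4   # 16x16x128 assumed


def lane_packed_scale(scale_tile, lane_id, k_tile_idx, num_mmas):
    """Build one lane's packed 32-bit scale word for k-tile k_tile_idx.

    scale_tile: hypothetical row-major list of rows of E8M0 bytes, shape
    [num_mmas * MMA_M, K // 32]. Byte i of the result is the scale for MMA tile i.
    """
    assert 0 <= lane_id < WARP_SIZE
    assert num_mmas <= 4, "one 32-bit word holds at most 4 scale bytes"
    lane_row = lane_id % MMA_M           # same (row, k-group) split as the fragments
    lane_k_group = lane_id // MMA_M
    col = k_tile_idx * SCALES_PER_MMA + lane_k_group
    word = 0
    for i in range(num_mmas):
        word |= (scale_tile[i * MMA_M + lane_row][col] & 0xFF) << (8 * i)
    return word
```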
The MFMA byte-index selector (a_scale_byte_index=m_mma, b_scale_byte_index=n_mma) picks the correct byte, so no shifts or masks are needed at consumption time.
mma
def mma[k_tile_idx: Int, slot: Int = 0](self)
Execute block-scaled MFMA for k-tile k_tile_idx using B from slot.
B → src_a, A → src_b (AMD MFMA convention). The packed scale VGPRs hold one byte per spatial MMA tile: a_scale_byte_index=m selects byte m from _a_scale_packed, and b_scale_byte_index=n selects byte n from _b_scale_packed.
slot selects which b_reg half to read when num_b_slots > 1.
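With num_b_slots = 2, the two slots let the B load for k-tile k + 1 be issued before the MFMAs for k-tile k consume the other slot. The sketch below generates such a depth-2 schedule as a list of ops. It is a hypothetical driver loop: ("load_b", k, s) stands for load_b_frag_preshuffled with k_tile_idx=k and slot=s, and ("mma", k, s) stands for mma[k, s]. The real kernel's ordering and synchronization may differ.

```python
def prefetch_schedule(num_k_tiles, num_b_slots=2):
    """Hypothetical depth-2 pipeline: B for k-tile k lives in slot k % num_b_slots."""
    assert num_k_tiles >= 1
    assert num_b_slots >= 2, "prefetching k + 1 needs a second b_reg slot"
    ops = [("load_b", 0, 0)]                      # prologue: fill slot 0
    for k in range(num_k_tiles):
        if k + 1 < num_k_tiles:
            # Prefetch the next k-tile into the slot the current MFMAs do not read.
            ops.append(("load_b", k + 1, (k + 1) % num_b_slots))
        ops.append(("mma", k, k % num_b_slots))   # consume the current slot
    return ops
```

For num_k_tiles = 3, the sequence is:

1. Load k-tile 0 into slot 0.
2. Load k-tile 1 into slot 1.
3. Run mma for k-tile 0 on slot 0.
4. Load k-tile 2 into slot 0.
5. Run mma for k-tile 1 on slot 1.
6. Run mma for k-tile 2 on slot 0.

Each slot is overwritten only after the mma that reads its previous contents has been issued.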