
Mojo struct

BlockScaledMmaOp

struct BlockScaledMmaOp[mma_shape: IndexList[Int(3)], num_m_mmas: Int, num_n_mmas: Int, num_k_tiles: Int, num_b_slots: Int = Int(1)]

Register ownership + block-scaled MFMA execution.

Loads packed uint8 A/B fragments from SMEM or GMEM and executes cdna4_block_scaled_mfma with per-lane E8M0 scale values.

Scale operand model: Each lane holds 32 FP4 elements and one E8M0 scale byte, matching the MX format's per-32-element granularity exactly. For 16x16x128, 64 lanes cover 16 rows x 4 K-groups:

lane_row = lane_id % 16 (matrix row)
lane_k_group = lane_id // 16 (K-group 0..3)
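The lane mapping above can be modeled in a few lines. This is a Python sketch for illustration only, not Mojo and not part of the BlockScaledMmaOp API; the helper name `lane_coords` and the constants are assumptions chosen to match the 16x16x128 case.

```python
WARP_SIZE = 64  # lanes per wavefront in the 16x16x128 case (assumed)
MMA_M = 16      # matrix rows covered by one MMA tile


def lane_coords(lane_id: int) -> tuple[int, int]:
    """Return (lane_row, lane_k_group) for one lane (hypothetical helper)."""
    if not 0 <= lane_id < WARP_SIZE:
        raise ValueError("lane_id out of range")
    return lane_id % MMA_M, lane_id // MMA_M


# Enumerate the (row, K-group) pair owned by every lane in the wavefront.
coords = [lane_coords(lane) for lane in range(WARP_SIZE)]
```

Under these assumptions each of the 16 x 4 (row, K-group) pairs is owned by exactly one lane, which is why one scale byte per lane is sufficient.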

Scale packing: The scale bytes of 4 spatial MMA tiles are packed into one Int32 VGPR, with byte i holding the scale for m_mma=i (A) or n_mma=i (B). The MFMA byte-index selector (OP_SEL) picks the correct byte for each MMA tile, so one scale load covers all 4 m_mma or n_mma positions with zero overhead.
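A small Python model may help make the packing concrete. It assumes byte i occupies bits [8i, 8i+8) of the 32-bit word (little-endian byte numbering), which the text does not state explicitly; `pack_scales` and `select_scale` are hypothetical names. On hardware the selection is performed by the MFMA instruction itself, so the shift and mask below only model that behavior.

```python
def pack_scales(scale_bytes: list[int]) -> int:
    """Pack up to 4 E8M0 scale bytes into one 32-bit word (byte i at bits 8i..8i+7)."""
    if len(scale_bytes) > 4:
        raise ValueError("at most 4 scale bytes fit in one 32-bit word")
    packed = 0
    for i, scale in enumerate(scale_bytes):
        if not 0 <= scale <= 0xFF:
            raise ValueError("scale must fit in one byte")
        packed |= scale << (8 * i)
    return packed


def select_scale(packed: int, byte_index: int) -> int:
    """Model of the byte-index selector: return byte `byte_index` of the word."""
    if not 0 <= byte_index < 4:
        raise ValueError("byte_index must be in 0..3")
    return (packed >> (8 * byte_index)) & 0xFF


# Scales for m_mma = 0, 1, 2, 3 held by one lane.
packed = pack_scales([0x7F, 0x80, 0x7E, 0x81])
```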

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

c_frag_size​

comptime c_frag_size = ((mma_shape[Int(0)] * mma_shape[Int(1)]) // _resolve_warp_size())

mma_frag_width_bytes​

comptime mma_frag_width_bytes = Int(16)

MMA_K​

comptime MMA_K = mma_shape[Int(2)]

MMA_M​

comptime MMA_M = mma_shape[Int(0)]

MMA_N​

comptime MMA_N = mma_shape[Int(1)]

packed_k_per_mma​

comptime packed_k_per_mma = (mma_shape[Int(2)] // Int(2))

scales_per_mma​

comptime scales_per_mma = (mma_shape[Int(2)] // Int(32))
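As a worked example of the shape-derived members above, the following Python sketch evaluates the same formulas for a given mma_shape. The function name `derived_constants` is hypothetical, and the warp size of 64 is an assumption taken from the 64-lane description earlier on this page.

```python
def derived_constants(mma_shape: tuple[int, int, int], warp_size: int = 64) -> dict[str, int]:
    """Evaluate the shape-derived comptime members for one MMA shape."""
    mma_m, mma_n, mma_k = mma_shape
    return {
        # Accumulator elements held per lane.
        "c_frag_size": (mma_m * mma_n) // warp_size,
        # Two FP4 values are packed per uint8, so K bytes = K / 2.
        "packed_k_per_mma": mma_k // 2,
        # One E8M0 scale per 32 elements along K.
        "scales_per_mma": mma_k // 32,
    }


constants = derived_constants((16, 16, 128))
```

For 16x16x128 this gives 4 accumulator elements per lane, 64 packed bytes along K, and 4 scales along K.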

smem_swizzle​

comptime smem_swizzle = Optional(Swizzle(Int(3), Int(0), Int(3))) if BlockScaledMmaOp[mma_shape, num_m_mmas, num_n_mmas, num_k_tiles, num_b_slots].use_smem_swizzle else Optional()

use_smem_swizzle​

comptime use_smem_swizzle = ((mma_shape[Int(2)] // Int(2)) == Int(64)) if (num_k_tiles == Int(1)) else (num_k_tiles == Int(1)) and True and (mma_shape[Int(0)] == Int(16))
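To give a feel for what a swizzle with parameters (3, 0, 3) does, here is a Python sketch. It assumes the three arguments mean (bits, base, shift) and follow the common XOR-swizzle convention, in which `bits` index bits located `shift` positions above `base` are XORed into the bits at `base`. This page does not confirm those semantics, nor the unit of the index being swizzled, so treat this as an assumption.

```python
def swizzle(index: int, bits: int = 3, base: int = 0, shift: int = 3) -> int:
    """XOR-swizzle model (assumed semantics): fold higher index bits into lower ones."""
    mask = ((1 << bits) - 1) << (base + shift)  # the source bits to fold down
    return index ^ ((index & mask) >> shift)


# Apply the swizzle to every index in a 64-entry window.
swizzled = [swizzle(i) for i in range(64)]
```

Under this model the mapping is a permutation of each 64-entry window, so no two indices collide after swizzling.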

Methods​

__init__​

def __init__(out self)

accum_tile​

def accum_tile(self) -> ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Returns:

ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

load_frag_from_smem​

def load_frag_from_smem[k_tile_idx: Int](self, a_smem_warp: TileTensor[DType.uint8, Storage=a_smem_warp.Storage, address_space=AddressSpace.SHARED, linear_idx_type=a_smem_warp.linear_idx_type, element_size=a_smem_warp.element_size], b_smem_warp: TileTensor[DType.uint8, Storage=b_smem_warp.Storage, address_space=AddressSpace.SHARED, linear_idx_type=b_smem_warp.linear_idx_type, element_size=b_smem_warp.element_size])

Load MXFP4 A/B fragments from row-major SMEM for k-tile k_tile_idx.

Uses tile to extract the [MMA_M, packed_k_per_mma] sub-tile; vectorize groups the 64 bytes into 4 x 16-byte elements; and distribute with col_major[MMA_M, 4] assigns each lane its 16-byte fragment, matching the MFMA native lane mapping.
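The tile, vectorize, and distribute steps can be pictured as plain index arithmetic. The Python sketch below assumes a dense row-major [16, 64] byte tile with no swizzle, and assumes that col_major[MMA_M, 4] maps lane l to (l % 16, l // 16), consistent with the lane mapping given earlier. The helper names are hypothetical and this is not the TileTensor API.

```python
MMA_M = 16       # rows in the sub-tile
PACKED_K = 64    # packed_k_per_mma for MMA_K = 128
VEC_BYTES = 16   # mma_frag_width_bytes


def lane_fragment_start(lane_id: int) -> int:
    """Byte offset of a lane's 16-byte fragment in a dense row-major tile."""
    row = lane_id % MMA_M    # col_major distribute: rows vary fastest across lanes
    vec = lane_id // MMA_M   # which of the 4 vectors in that row
    return row * PACKED_K + vec * VEC_BYTES


def load_fragment(tile: bytes, lane_id: int) -> bytes:
    """Return the 16 packed bytes (32 FP4 values) owned by one lane."""
    start = lane_fragment_start(lane_id)
    return tile[start:start + VEC_BYTES]


tile = bytes(i % 256 for i in range(MMA_M * PACKED_K))
starts = sorted(lane_fragment_start(lane) for lane in range(64))
```

In this model the 64 lanes together cover all 1024 bytes of the sub-tile exactly once.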

load_a_frag_from_smem​

def load_a_frag_from_smem[k_tile_idx: Int](self, a_smem_warp: TileTensor[DType.uint8, Storage=a_smem_warp.Storage, address_space=AddressSpace.SHARED, linear_idx_type=a_smem_warp.linear_idx_type, element_size=a_smem_warp.element_size])

A-only variant of load_frag_from_smem for callers that source B elsewhere (e.g. preshuffled DRAM via PreshuffledBLoader).

load_b_frag_preshuffled​

def load_b_frag_preshuffled[k_tile_idx: Int, N: Int, K_BYTES: Int, slot: Int = Int(0)](self, b_loader: PreshuffledBLoader[N, K_BYTES], warp_n_off: Int, k_byte_base: Int)

Load B fragments directly from preshuffled DRAM into b_reg slot slot.

Each lane issues one buffer_load_dwordx4 per (k_tile, n_mma) at the per-lane MFMA mapping (lane%16 → n-row, lane//16 → k-group). The slot parameter selects which b_reg half to write into when num_b_slots > 1 (depth-2 prefetch).

load_scales_from_smem​

def load_scales_from_smem[k_tile_idx: Int](mut self, a_scale_smem_warp: TileTensor[DType.uint8, Storage=a_scale_smem_warp.Storage, address_space=AddressSpace.SHARED, linear_idx_type=a_scale_smem_warp.linear_idx_type, element_size=a_scale_smem_warp.element_size], b_scale_smem_warp: TileTensor[DType.uint8, Storage=b_scale_smem_warp.Storage, address_space=AddressSpace.SHARED, linear_idx_type=b_scale_smem_warp.linear_idx_type, element_size=b_scale_smem_warp.element_size])

Load packed scale VGPRs for k-tile k_tile_idx from SMEM.

Packs num_m_mmas (A) or num_n_mmas (B) scale bytes into one Int32 each using the same col_major[MMA_M, WARP_SIZE/MMA_M] distribute pattern as load_frag_from_smem. Each lane picks one scale byte via (lane_row, lane_k_group). TileTensor's stride handling means this works for any parent SMEM layout.

The MFMA byte-index selector (a_scale_byte_index=m_mma, b_scale_byte_index=n_mma) picks the correct byte, so no shifts or masks are needed at consumption time.
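The per-lane gather described above can be sketched in Python as follows. The nested-list layout `scales[m_mma][row][k_group]` is an assumption made for readability (the real code works on a TileTensor with arbitrary strides), the byte order within the word is assumed little-endian, and `packed_scale_for_lane` is a hypothetical name.

```python
def packed_scale_for_lane(scales: list[list[list[int]]], lane_id: int, mma_m: int = 16) -> int:
    """Gather one scale byte per spatial MMA tile and pack them into a 32-bit word."""
    if len(scales) > 4:
        raise ValueError("at most 4 MMA tiles fit in one packed word")
    lane_row, lane_k_group = lane_id % mma_m, lane_id // mma_m
    packed = 0
    for mma_idx, scale_tile in enumerate(scales):
        scale = scale_tile[lane_row][lane_k_group] & 0xFF
        packed |= scale << (8 * mma_idx)  # byte mma_idx holds that tile's scale
    return packed


# Synthetic scales: 4 MMA tiles, 16 rows, 4 K-groups, each entry uniquely numbered.
scales = [[[(m * 64 + g * 16 + r) & 0xFF for g in range(4)] for r in range(16)] for m in range(4)]
packed = packed_scale_for_lane(scales, lane_id=21)
```

Lane 21 maps to row 5 and K-group 1, so its packed word holds that position's scale from each of the four tiles.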

mma​

def mma[k_tile_idx: Int, slot: Int = Int(0)](self)

Execute block-scaled MFMA for k-tile k_tile_idx using B from slot.

B→src_a, A→src_b (AMD MFMA convention). The packed scale VGPRs hold one byte per spatial MMA tile. a_scale_byte_index=m selects byte m from _a_scale_packed, b_scale_byte_index=n selects byte n from _b_scale_packed.

slot selects which b_reg half to read when num_b_slots > 1.
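For readers unfamiliar with block scaling, the following Python sketch shows the arithmetic that one 32-element block contributes to an accumulator, using the MX-format encodings (FP4 E2M1 elements, E8M0 power-of-two scales). It models the numerics only: it is not the cdna4_block_scaled_mfma intrinsic, it uses Python floats rather than float32 accumulation, and all function names are hypothetical.

```python
FP4_E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_fp4(code: int) -> float:
    """Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0..2 select the magnitude."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * FP4_E2M1_MAGNITUDES[code & 0x7]


def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale: 2^(byte - 127), with 0xFF reserved for NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def block_scaled_dot(a_codes: list[int], b_codes: list[int], a_scale: int, b_scale: int) -> float:
    """Dot product of one 32-element block, scaled by both operands' block scales."""
    if len(a_codes) != 32 or len(b_codes) != 32:
        raise ValueError("MX blocks hold exactly 32 elements")
    acc = sum(decode_fp4(a) * decode_fp4(b) for a, b in zip(a_codes, b_codes))
    return decode_e8m0(a_scale) * decode_e8m0(b_scale) * acc


# 32 elements of 1.0 (scale 2.0) against 32 elements of 2.0 (scale 0.5).
result = block_scaled_dot([0x2] * 32, [0x4] * 32, a_scale=128, b_scale=126)
```

Because both scales are powers of two, applying them once per block after the inner sum is exact in this model.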