
Mojo struct

BlockScaledMmaOp_PreB

struct BlockScaledMmaOp_PreB[mma_shape: IndexList[3], warp_tile: IndexList[3], num_b_slots: Int = 1]

Per-warp register state + MFMA dispatch for the preb (preshuffled-B, preshuffled-scale) kernel.

warp_tile is the (M, N, K) region this warp computes per outer-K iteration, in the same element units as mma_shape. Per-warp MFMA counts are derived as warp_tile[i] // mma_shape[i].
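The per-axis division can be modeled on the host as a quick check of which MFMA counts a given configuration produces. This is a minimal Python sketch, not the Mojo implementation. The helper name mfma_counts and the example shapes are hypothetical.

```python
# Hypothetical host-side model of how per-warp MFMA counts follow from
# mma_shape and warp_tile (both given as (M, N, K) in element units).

def mfma_counts(mma_shape, warp_tile):
    """Return (num_m_mmas, num_n_mmas, num_k_mmas) = warp_tile[i] // mma_shape[i]."""
    return tuple(warp_tile[i] // mma_shape[i] for i in range(3))

# Illustrative shapes only: a 16x16x128 MFMA tiled over a 32x64x256 warp tile.
print(mfma_counts((16, 16, 128), (32, 64, 256)))  # (2, 4, 2)
```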

__init__ asserts that warp_tile[i] % mma_shape[i] == 0 on every axis and that num_k_mmas % 2 == 0 (k_pack=2: each scale cell spans two K halves). num_m_mmas and num_n_mmas may be odd. In that case the constructor rotates the scale i32 per CTA so that OPSEL can keep the same comptime formula. See the module-level comment for the scale-cell byte ordering.
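The constraints can be checked on the host before choosing a configuration. The following Python sketch mirrors the listed checks only. The function name check_warp_tile is hypothetical, and the real checks live in the Mojo constructor.

```python
# Hypothetical host-side mirror of the constraints asserted in __init__.

def check_warp_tile(mma_shape, warp_tile):
    # Every axis of warp_tile must be a whole multiple of the MFMA shape.
    for axis in range(3):
        if warp_tile[axis] % mma_shape[axis] != 0:
            raise ValueError(
                f"warp_tile[{axis}]={warp_tile[axis]} is not a multiple "
                f"of mma_shape[{axis}]={mma_shape[axis]}"
            )
    # K must split into pairs of MFMAs (k_pack=2 halves per scale cell).
    num_k_mmas = warp_tile[2] // mma_shape[2]
    if num_k_mmas % 2 != 0:
        raise ValueError(f"num_k_mmas={num_k_mmas} must be even")
    # Odd num_m_mmas / num_n_mmas are allowed; nothing to check there.
    return True

print(check_warp_tile((16, 16, 128), (16, 48, 256)))  # True: M and N counts are odd, K count is 2
```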

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

c_frag_size

comptime c_frag_size = ((BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_M * BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_N) // WARP_SIZE)

mma_frag_width_bytes

comptime mma_frag_width_bytes = 16

MMA_K

comptime MMA_K = mma_shape[2]

MMA_K_BYTES

comptime MMA_K_BYTES = (BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_K // 2)

MMA_M

comptime MMA_M = mma_shape[0]

MMA_N

comptime MMA_N = mma_shape[1]

num_k_mmas

comptime num_k_mmas = (warp_tile[2] // BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_K)

num_m_mmas

comptime num_m_mmas = (warp_tile[0] // BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_M)

num_n_mmas

comptime num_n_mmas = (warp_tile[1] // BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_N)
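To see what the fragment-size members evaluate to for a concrete shape, the Python sketch below recomputes c_frag_size and MMA_K_BYTES. It assumes WARP_SIZE = 64 (the AMD wavefront size) and 4-bit elements packed two per byte, which is what MMA_K // 2 implies. Both are assumptions, and the shapes are illustrative.

```python
# Hypothetical host-side recomputation of the fragment-size comptime members.
WARP_SIZE = 64  # assumption: AMD CDNA wavefront size

def fragment_members(mma_shape):
    MMA_M, MMA_N, MMA_K = mma_shape
    return {
        # Two 4-bit elements per byte (assumed from MMA_K // 2).
        "MMA_K_BYTES": MMA_K // 2,
        # f32 accumulator elements held by each lane for one MFMA output tile.
        "c_frag_size": (MMA_M * MMA_N) // WARP_SIZE,
    }

print(fragment_members((16, 16, 128)))  # {'MMA_K_BYTES': 64, 'c_frag_size': 4}
print(fragment_members((32, 32, 64)))   # {'MMA_K_BYTES': 32, 'c_frag_size': 16}
```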

Methods

__init__

def __init__(out self, warp_m_off: Int, warp_n_off: Int)

accum_tile

def accum_tile(self) -> ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Returns:

ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

load_a_frag_from_smem

def load_a_frag_from_smem[mma_k_idx: Int](self, a_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=a_smem_warp.linear_idx_type, element_size=a_smem_warp.element_size])

Load A fragment for MFMA-K position mma_k_idx from row-major SMEM.

load_b_frag_preshuffled

def load_b_frag_preshuffled[mma_k_idx: Int, slot: Int = 0](self, b_loader: PreshuffledBLoader, warp_n_off: Int, k_byte_base: Int)

Load B fragments directly from preshuffled DRAM into the b_reg slot selected by the slot parameter.

load_a_scales_preshuffled

def load_a_scales_preshuffled[k_pair: Int](mut self, a_scale_loader: PreshuffledScaleLoader, warp_m_off: Int, k_pair_idx: Int)

Issue per-lane i32 scale loads for A at one k_pair slot.

The caller provides the absolute k_pair_idx, computed as k_iter * (num_k_mmas / 2) + k_pair. Each increment of k_pair_idx advances by 8 K-scales (2 MFMAs along K). Each lane loads one i32 per (mi_pair, k_pair).
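The index arithmetic can be modeled directly. This Python sketch computes k_pair_idx for each (k_iter, k_pair) and the first K-scale index it covers, using the 8-scales-per-step figure stated above. The names k_pair_index and first_k_scale are hypothetical helpers, not part of the API.

```python
# Hypothetical model of the absolute k_pair index passed to the scale loaders.
SCALES_PER_K_PAIR = 8  # from the docstring: one step = 8 K-scales = 2 MFMAs along K

def k_pair_index(k_iter, k_pair, num_k_mmas):
    assert num_k_mmas % 2 == 0, "num_k_mmas must be even"
    assert 0 <= k_pair < num_k_mmas // 2, "k_pair out of range"
    return k_iter * (num_k_mmas // 2) + k_pair

def first_k_scale(k_iter, k_pair, num_k_mmas):
    # Index of the first K-scale covered by this (k_iter, k_pair) step.
    return k_pair_index(k_iter, k_pair, num_k_mmas) * SCALES_PER_K_PAIR

# With num_k_mmas = 4 there are 2 k_pairs per outer-K iteration.
print([k_pair_index(i, p, 4) for i in range(3) for p in range(2)])  # [0, 1, 2, 3, 4, 5]
print(first_k_scale(1, 1, 4))  # 24
```

Consecutive (k_iter, k_pair) steps map to consecutive k_pair_idx values, so the loaders never skip or revisit a scale group.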

load_b_scales_preshuffled

def load_b_scales_preshuffled[k_pair: Int](mut self, b_scale_loader: PreshuffledScaleLoader, warp_n_off: Int, k_pair_idx: Int)

Mirror of load_a_scales_preshuffled along N.

mma

def mma[mma_k_idx: Int, slot: Int = 0](self)

Execute block-scaled MFMA at MFMA-K position mma_k_idx using B from slot.

B→src_a, A→src_b (AMD MFMA convention).

OPSEL byte selection picks the correct byte from the 2x2 scale cell:

a_byte = (mma_k_idx % 2) * 2 + (m % 2)
b_byte = (mma_k_idx % 2) * 2 + (n % 2)

The scale dword is read from _a_scale_packed[m // 2, mma_k_idx // 2] for A and from _b_scale_packed[n // 2, mma_k_idx // 2] for B.
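The OPSEL arithmetic can be checked by evaluating it for a few indices. The Python sketch below applies the two byte formulas and the dword index. It does not model the physical byte ordering inside the cell, which the module-level comment defines. The helper names are hypothetical.

```python
# Hypothetical model of OPSEL byte selection within a 2x2 scale cell.

def opsel_bytes(mma_k_idx, m, n):
    k_half = mma_k_idx % 2           # which K half of the cell
    a_byte = k_half * 2 + (m % 2)    # A scale byte within the dword
    b_byte = k_half * 2 + (n % 2)    # B scale byte within the dword
    return a_byte, b_byte

def scale_dword_index(mn, mma_k_idx):
    # Row/column of the packed i32 that holds this MFMA's scales.
    return (mn // 2, mma_k_idx // 2)

def extract_byte(dword, byte):
    return (dword >> (8 * byte)) & 0xFF

print(opsel_bytes(0, 0, 0))                  # (0, 0)
print(opsel_bytes(1, 1, 0))                  # (3, 2)
print(scale_dword_index(3, 5))               # (1, 2)
print(hex(extract_byte(0x44332211, 3)))      # 0x44
```

Each of the four (K half, m parity) combinations lands on a distinct byte 0 to 3, so one i32 covers a full 2x2 cell.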

WM=16 / WN=16 case: every CTA only ever sees m=0 (respectively n=0), so OPSEL is fixed at byte 0 (or byte 2 for k_pack=1). To compensate, the constructor records a shrui amount (_a_scale_shift / _b_scale_shift) that rotates the i32 right by 0 or 8 bits. The byte that OPSEL selects is then the one for this CTA's half of the cell.
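The per-CTA rotation trick can be modeled in a few lines. This Python sketch rotates the packed i32 right by 0 or 8 bits and then reads the fixed OPSEL byte (0 or 2). It confirms that the result matches selecting byte 2 * k_half + cta_half from the unrotated dword. The cta_half parameter, meaning which half of the cell this CTA owns, is a hypothetical stand-in for how the constructor picks the shift. A logical right shift (as the name shrui suggests) yields the same bytes 0 and 2, since only bytes 1 and 3 need to move down.

```python
# Hypothetical model of the per-CTA scale rotation used when WM or WN is 16.
MASK32 = 0xFFFFFFFF

def rotr32(x, s):
    # 32-bit rotate right.
    s %= 32
    x &= MASK32
    return ((x >> s) | (x << (32 - s))) & MASK32

def select_byte(dword, byte):
    return (dword >> (8 * byte)) & 0xFF

def scale_for_cta(dword, cta_half, k_half):
    shift = 8 * cta_half      # assumed: 0 for the even half, 8 for the odd half
    fixed_opsel = 2 * k_half  # OPSEL stays at byte 0 or byte 2
    return select_byte(rotr32(dword, shift), fixed_opsel)

cell = 0x44332211  # bytes 0..3 = 0x11, 0x22, 0x33, 0x44
print([hex(scale_for_cta(cell, h, k)) for k in range(2) for h in range(2)])
# ['0x11', '0x22', '0x33', '0x44']
```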