Mojo struct
BlockScaledMmaOp_PreB
struct BlockScaledMmaOp_PreB[mma_shape: IndexList[3], warp_tile: IndexList[3], num_b_slots: Int = 1]
Per-warp register state and MFMA dispatch for the preb (preshuffled-B, preshuffled-scale) kernel.
warp_tile is the (M, N, K) region this warp computes per outer-K
iteration, in the same element units as mma_shape. Per-warp MFMA
counts are derived as warp_tile[i] // mma_shape[i].
__init__ asserts that warp_tile[i] % mma_shape[i] == 0 on every axis
and that num_k_mmas % 2 == 0 (k_pack=2: each scale cell has two K
halves, one per MFMA). num_m_mmas and num_n_mmas may be odd; to handle
this, the constructor shifts the packed i32 scale per CTA so that OPSEL
keeps the same comptime formula. See the module-level comment for the
scale-cell byte ordering.
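As a rough host-side illustration of the count derivation and the __init__ checks, the plain-Python model below mirrors the rules stated above. It is a sketch, not the Mojo implementation. The function name derive_mma_counts is hypothetical, it raises ValueError where the struct asserts, and the 16x16x128 example shape is an assumption chosen for illustration.

```python
def derive_mma_counts(mma_shape, warp_tile):
    """Hypothetical host-side model of the per-warp MFMA counts and __init__ checks."""
    if len(mma_shape) != 3 or len(warp_tile) != 3:
        raise ValueError("mma_shape and warp_tile must both be (M, N, K)")
    for axis, name in enumerate("MNK"):
        # Each axis of the warp tile must hold a whole number of MFMA tiles.
        if warp_tile[axis] % mma_shape[axis] != 0:
            raise ValueError(
                f"warp_tile[{axis}] ({name}) is not a multiple of mma_shape[{axis}]"
            )
    num_m_mmas = warp_tile[0] // mma_shape[0]
    num_n_mmas = warp_tile[1] // mma_shape[1]
    num_k_mmas = warp_tile[2] // mma_shape[2]
    # Scale cells pair up MFMAs along K (k_pack=2), so the K count must be even.
    if num_k_mmas % 2 != 0:
        raise ValueError("num_k_mmas must be even")
    # num_m_mmas and num_n_mmas are allowed to be odd: no check on them.
    return num_m_mmas, num_n_mmas, num_k_mmas


# Example with an assumed 16x16x128 MFMA shape and a 32x16x256 warp tile.
print(derive_mma_counts((16, 16, 128), (32, 16, 256)))  # (2, 1, 2)
```

The odd num_n_mmas in the example (1) is accepted, while a warp tile whose K covers only one MFMA would be rejected.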
Implemented traits
comptime members
c_frag_size
comptime c_frag_size = ((BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_M * BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_N) // WARP_SIZE)
mma_frag_width_bytes
comptime mma_frag_width_bytes = 16
MMA_K
comptime MMA_K = mma_shape[2]
MMA_K_BYTES
comptime MMA_K_BYTES = (BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_K // 2)
MMA_M
comptime MMA_M = mma_shape[0]
MMA_N
comptime MMA_N = mma_shape[1]
num_k_mmas
comptime num_k_mmas = (warp_tile[2] // BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_K)
num_m_mmas
comptime num_m_mmas = (warp_tile[0] // BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_M)
num_n_mmas
comptime num_n_mmas = (warp_tile[1] // BlockScaledMmaOp_PreB[mma_shape, warp_tile, num_b_slots].MMA_N)
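To sanity-check the comptime members above, the following Python sketch evaluates them for two assumed MFMA shapes. It assumes WARP_SIZE = 64 (the AMD CDNA wavefront size) and reads MMA_K_BYTES = MMA_K // 2 as two 4-bit elements per byte. derived_constants and the extra a_bytes_per_lane entry are hypothetical and are not part of the struct.

```python
WARP_SIZE = 64  # assumption: AMD CDNA wavefront size


def derived_constants(mma_shape, warp_tile):
    """Hypothetical model of the comptime members of BlockScaledMmaOp_PreB."""
    mma_m, mma_n, mma_k = mma_shape
    mma_k_bytes = mma_k // 2  # assumption: two 4-bit elements per byte
    return {
        "MMA_M": mma_m,
        "MMA_N": mma_n,
        "MMA_K": mma_k,
        "MMA_K_BYTES": mma_k_bytes,
        # Accumulator elements each lane holds for one MFMA output tile.
        "c_frag_size": (mma_m * mma_n) // WARP_SIZE,
        "num_m_mmas": warp_tile[0] // mma_m,
        "num_n_mmas": warp_tile[1] // mma_n,
        "num_k_mmas": warp_tile[2] // mma_k,
        # Not a struct member: A bytes per lane for one MFMA, for comparison
        # with mma_frag_width_bytes = 16.
        "a_bytes_per_lane": (mma_m * mma_k_bytes) // WARP_SIZE,
    }


for shape in [(16, 16, 128), (32, 32, 64)]:
    c = derived_constants(shape, (32, 32, 256))
    print(shape, c["c_frag_size"], c["MMA_K_BYTES"], c["a_bytes_per_lane"])
```

Under these assumptions, both shapes give 16 A bytes per lane per MFMA, which is consistent with mma_frag_width_bytes = 16. Treat that as an observation, not a documented invariant.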
Methods
__init__
def __init__(out self, warp_m_off: Int, warp_n_off: Int)
accum_tile
def accum_tile(self) -> ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
Returns:
ref[self._c_reg] TileTensor[DType.float32, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
load_a_frag_from_smem
def load_a_frag_from_smem[mma_k_idx: Int](self, a_smem_warp: TileTensor[DType.uint8, address_space=AddressSpace.SHARED, linear_idx_type=a_smem_warp.linear_idx_type, element_size=a_smem_warp.element_size])
Load the A fragment for MFMA-K position mma_k_idx from row-major SMEM.
load_b_frag_preshuffled
def load_b_frag_preshuffled[mma_k_idx: Int, slot: Int = 0](self, b_loader: PreshuffledBLoader, warp_n_off: Int, k_byte_base: Int)
Load B fragments directly from preshuffled DRAM into the b_reg slot selected by slot.
load_a_scales_preshuffled
def load_a_scales_preshuffled[k_pair: Int](mut self, a_scale_loader: PreshuffledScaleLoader, warp_m_off: Int, k_pair_idx: Int)
Issue per-lane i32 scale loads for A for one k_pair.
The caller provides the absolute k_pair_idx (= k_iter * (num_k_mmas // 2) + k_pair); each increment of k_pair_idx advances by 8 K-scales
(2 MFMAs along K). Each lane loads one i32 per (mi_pair, k_pair).
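The indexing contract for the scale loads can be modeled as follows. This Python sketch lists the i32 loads that one lane would issue in a single outer-K iteration. a_scale_load_plan is hypothetical. Rounding an odd num_m_mmas up to a final, half-used pair is an assumption based on the odd-count note at the top of this page.

```python
def a_scale_load_plan(k_iter, num_k_mmas, num_m_mmas):
    """Hypothetical model: the (mi_pair, k_pair) i32 loads one lane issues per outer-K iteration."""
    if num_k_mmas % 2 != 0:
        raise ValueError("num_k_mmas must be even")
    k_pairs_per_iter = num_k_mmas // 2
    # Assumption: an odd num_m_mmas still needs a (half-used) final pair.
    num_mi_pairs = (num_m_mmas + 1) // 2
    loads = []
    for k_pair in range(k_pairs_per_iter):
        # Absolute index the caller passes as k_pair_idx.
        k_pair_idx = k_iter * k_pairs_per_iter + k_pair
        # Each k_pair_idx step covers 8 K-scales (2 MFMAs along K),
        # counting from K-scale 0 at k_pair_idx 0.
        first_k_scale = 8 * k_pair_idx
        for mi_pair in range(num_mi_pairs):
            loads.append((mi_pair, k_pair, k_pair_idx, first_k_scale))
    return loads


for load in a_scale_load_plan(k_iter=1, num_k_mmas=4, num_m_mmas=3):
    print(load)  # (mi_pair, k_pair, k_pair_idx, first_k_scale)
```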
load_b_scales_preshuffled
def load_b_scales_preshuffled[k_pair: Int](mut self, b_scale_loader: PreshuffledScaleLoader, warp_n_off: Int, k_pair_idx: Int)
Mirror of load_a_scales_preshuffled along N.
mma
def mma[mma_k_idx: Int, slot: Int = 0](self)
Execute the block-scaled MFMA at MFMA-K position mma_k_idx, using the B fragment in the b_reg slot selected by slot.
Operands are swapped: B is passed as src_a and A as src_b (AMD MFMA convention).
OPSEL byte selection picks the right byte from the 2x2 cell:
a_byte = (mma_k_idx % 2) * 2 + (m % 2)
b_byte = (mma_k_idx % 2) * 2 + (n % 2)
The scale dword lives at _a_scale_packed[m // 2, mma_k_idx // 2] for A and at _b_scale_packed[n // 2, mma_k_idx // 2] for B.
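You can check the OPSEL formulas and the dword indexing with a small Python model. The helper names are hypothetical. select_byte assumes that OPSEL byte 0 is the least-significant byte of the i32.

```python
def opsel_bytes(mma_k_idx, m, n):
    """OPSEL byte for A and B inside the 2x2 scale cell (formulas above)."""
    a_byte = (mma_k_idx % 2) * 2 + (m % 2)
    b_byte = (mma_k_idx % 2) * 2 + (n % 2)
    return a_byte, b_byte


def scale_dword_coords(mn, mma_k_idx):
    """[row, col] into _a_scale_packed (mn = m) or _b_scale_packed (mn = n)."""
    return mn // 2, mma_k_idx // 2


def select_byte(dword, byte):
    """Extract one byte from a 32-bit word (assumption: byte 0 is the LSB)."""
    return (dword >> (8 * byte)) & 0xFF


# Walk one 2x2 cell: two K positions x two M positions.
for mma_k_idx in range(2):
    for m in range(2):
        a_byte, _ = opsel_bytes(mma_k_idx, m, n=0)
        print(mma_k_idx, m, a_byte)
```

The K parity selects the half of the cell (bytes 0-1 or 2-3), and the M/N parity selects the byte within that half.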
WM=16 / WN=16 case: the warp tile is one MFMA wide in M (or N), so
num_m_mmas (or num_n_mmas) is 1 and every CTA only ever sees m=0
(or n=0). OPSEL is then fixed at byte 0, or byte 2 for the second K
half (mma_k_idx % 2 == 1). The constructor records a shrui amount
(_a_scale_shift / _b_scale_shift) that shifts the i32 right by 0 or 8
bits, so the byte OPSEL selects is the one for this CTA's half of the cell.
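The effect of the per-CTA shift can be illustrated in Python. This sketch assumes the shift is a logical right shift (as shrui suggests) and that OPSEL byte 0 is the least-significant byte. select_scale_wm16 is hypothetical, and the sketch does not model how the constructor chooses 0 or 8 for a given CTA.

```python
def select_scale_wm16(packed, shift, mma_k_idx):
    """Model of the WM=16 / WN=16 path: fixed OPSEL plus a per-CTA right shift."""
    if shift not in (0, 8):
        raise ValueError("shift must be 0 or 8")
    # Only m == 0 occurs, so OPSEL is byte 0 (first K half) or byte 2 (second K half).
    opsel = (mma_k_idx % 2) * 2
    # shrui: logical (unsigned) right shift of the 32-bit word.
    shifted = (packed & 0xFFFFFFFF) >> shift
    return (shifted >> (8 * opsel)) & 0xFF


packed = 0x44332211  # bytes 0..3 = 0x11, 0x22, 0x33, 0x44
for shift in (0, 8):
    print(shift, [hex(select_scale_wm16(packed, shift, k)) for k in range(2)])
# shift 0 reads bytes 0 and 2; shift 8 reads bytes 1 and 3.
```

A shift of 8 makes the fixed byte 0/2 OPSEL read bytes 1/3 of the original word. That is exactly what the general formula gives for m % 2 == 1.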