For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
Shuffler
struct Shuffler[E: Int]
MXFP4 preshuffle layouts and helpers for AMD CDNA4.
Parameters
- E (Int): Number of groups (experts / sort-blocks) the shuffler operates on. Use `Shuffler[1]` for single-group consumers.
Implemented traits
comptime members
b_5d_grouped_layout
comptime b_5d_grouped_layout[N: Int, K_BYTES: Int] = Layout(Coord(ComptimeInt(), Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt(), ComptimeInt())), Coord(ComptimeInt(), Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt(), ComptimeInt())))
Parameters
B_STRIDE_K0
comptime B_STRIDE_K0 = (Shuffler[E].MFMA_K_LANES * Shuffler[E].B_STRIDE_K_LANE)
B_STRIDE_K_LANE
comptime B_STRIDE_K_LANE = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].MFMA_LANE_BYTES)
B_STRIDE_LANE_BYTES
comptime B_STRIDE_LANE_BYTES = 1
B_STRIDE_MN_LANE
comptime B_STRIDE_MN_LANE = Shuffler[E].MFMA_LANE_BYTES
BTileTensor
comptime BTileTensor[N: Int, K_BYTES: Int] = TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]
Parameters
MFMA_K_BYTES
comptime MFMA_K_BYTES = (Shuffler[E].MFMA_K_LANES * Shuffler[E].MFMA_LANE_BYTES)
MFMA_K_LANES
comptime MFMA_K_LANES = 4
MFMA_LANE_BYTES
comptime MFMA_LANE_BYTES = 16
MFMA_MN_LANES
comptime MFMA_MN_LANES = 16
NUM_THREADS
comptime NUM_THREADS = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].MFMA_K_LANES)
packed_scale_bytes
comptime packed_scale_bytes = 4
S_K_BLOCK
comptime S_K_BLOCK = (Shuffler[E].MFMA_K_LANES * Shuffler[E].S_K_PACK)
S_K_PACK
comptime S_K_PACK = 2
S_MN_BLOCK
comptime S_MN_BLOCK = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].S_MN_PACK)
S_MN_PACK
comptime S_MN_PACK = 2
Methods
scale_4d_byte_off
static def scale_4d_byte_off[K_SCALES: Int, packed_mode: Bool = False](mn: Int, k_scale: Int) -> Int
Returns:
Int
scale_padded_mn
static def scale_padded_mn(MN: Int) -> Int
Padded MN dimension used by the 4D scale layout: MN rounded up to the next multiple of 32.
Returns:
Int
preshuffle_b_5d
static def preshuffle_b_5d[N: Int, K_BYTES: Int](raw: TileTensor[DType.uint8, linear_idx_type=raw.linear_idx_type, element_size=raw.element_size], dst: TileTensor[DType.uint8, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ctx: DeviceContext)
Launches the GPU MXFP4 B 5D preshuffle.
Invoked eagerly from model weight adapters (one-shot graph) so
the shuffle runs once at `session.load` instead of on the roughly
hours-long NumPy CPU path. Mirrors the origin-handling pattern of
`block_scales_interleave` (accept any origin, cast to any-origin
for the kernel).
Parameters:
- N (Int): Per-expert N (must be a multiple of 16).
- K_BYTES (Int): Per-expert FP4-packed K (must be a multiple of 64).
Args:
- raw (TileTensor[DType.uint8, linear_idx_type=raw.linear_idx_type, element_size=raw.element_size]): Row-major source weights `[E, N, K_BYTES]`.
- dst (TileTensor[DType.uint8, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size]): Destination buffer (same byte footprint; bytes are written in `b_5d_grouped_layout` order).
- ctx (DeviceContext): AMD device context.
preshuffle_scale_4d
static def preshuffle_scale_4d[MN: Int, K_SCALES: Int, SrcLayout: TensorLayout](src: TileTensor[DType.uint8, SrcLayout, MutAnyOrigin], mut dst: HostBuffer[DType.uint8])
preshuffle_grouped_scale_4d_gpu
static def preshuffle_grouped_scale_4d_gpu[K_SCALES: Int, SfaRawLayout: TensorLayout, SfaPreLayout: TensorLayout, AOffsetsLayout: TensorLayout](sfa_raw: TileTensor[DType.uint8, SfaRawLayout, linear_idx_type=sfa_raw.linear_idx_type, element_size=sfa_raw.element_size], sfa_pre: TileTensor[DType.uint8, SfaPreLayout, linear_idx_type=sfa_pre.linear_idx_type, element_size=sfa_pre.element_size], a_offsets: TileTensor[DType.uint32, AOffsetsLayout, linear_idx_type=a_offsets.linear_idx_type, element_size=a_offsets.element_size], num_active_experts: Int, max_num_tokens_per_expert: Int, total_wg: Int, ctx: DeviceContext)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!