
Mojo struct

Shuffler

struct Shuffler[E: Int]

MXFP4 preshuffle layouts and helpers for AMD CDNA4.

Parameters​

  • ​E (Int): Number of groups (experts / sort-blocks) the shuffler operates on. Use Shuffler[1] for single-group consumers.

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

b_5d_grouped_layout​

comptime b_5d_grouped_layout[N: Int, K_BYTES: Int] = Layout(Coord(ComptimeInt(), Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt(), ComptimeInt())), Coord(ComptimeInt(), Coord(ComptimeInt(), ComptimeInt()), Coord(ComptimeInt(), ComptimeInt(), ComptimeInt())))

Parameters​

  • ​N (Int):
  • ​K_BYTES (Int):

B_STRIDE_K0​

comptime B_STRIDE_K0 = (Shuffler[E].MFMA_K_LANES * Shuffler[E].B_STRIDE_K_LANE)

B_STRIDE_K_LANE​

comptime B_STRIDE_K_LANE = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].MFMA_LANE_BYTES)

B_STRIDE_LANE_BYTES​

comptime B_STRIDE_LANE_BYTES = 1

B_STRIDE_MN_LANE​

comptime B_STRIDE_MN_LANE = Shuffler[E].MFMA_LANE_BYTES
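The four B strides resolve to concrete byte counts once the MFMA constants listed on this page (MFMA_K_LANES = 4, MFMA_MN_LANES = 16, MFMA_LANE_BYTES = 16) are substituted. The Python sketch below only mirrors that arithmetic; it is not the Mojo API. The helper `tile_byte_offset` is hypothetical and assumes the stride names describe how a (k_lane, mn_lane, lane_byte) coordinate maps to a byte offset inside one K0 block.

```python
# Values copied from the comptime members on this page.
MFMA_K_LANES = 4
MFMA_MN_LANES = 16
MFMA_LANE_BYTES = 16

B_STRIDE_LANE_BYTES = 1
B_STRIDE_MN_LANE = MFMA_LANE_BYTES                  # 16
B_STRIDE_K_LANE = MFMA_MN_LANES * MFMA_LANE_BYTES   # 256
B_STRIDE_K0 = MFMA_K_LANES * B_STRIDE_K_LANE        # 1024


def tile_byte_offset(k_lane: int, mn_lane: int, lane_byte: int) -> int:
    """Hypothetical helper: byte offset inside one K0 block (assumed mapping)."""
    return (
        k_lane * B_STRIDE_K_LANE
        + mn_lane * B_STRIDE_MN_LANE
        + lane_byte * B_STRIDE_LANE_BYTES
    )


# Enumerate every coordinate of one block and collect its byte offset.
offsets = sorted(
    tile_byte_offset(k, mn, b)
    for k in range(MFMA_K_LANES)
    for mn in range(MFMA_MN_LANES)
    for b in range(MFMA_LANE_BYTES)
)
```

Under that assumption the strides tile one 1024-byte block densely: lane bytes are innermost, MN lanes come next, and K lanes are outermost.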

BTileTensor​

comptime BTileTensor[N: Int, K_BYTES: Int] = TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]

Parameters​

  • ​N (Int):
  • ​K_BYTES (Int):

MFMA_K_BYTES​

comptime MFMA_K_BYTES = (Shuffler[E].MFMA_K_LANES * Shuffler[E].MFMA_LANE_BYTES)

MFMA_K_LANES​

comptime MFMA_K_LANES = 4

MFMA_LANE_BYTES​

comptime MFMA_LANE_BYTES = 16

MFMA_MN_LANES​

comptime MFMA_MN_LANES = 16

NUM_THREADS​

comptime NUM_THREADS = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].MFMA_K_LANES)
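As a worked check of the derived MFMA constants, the Python sketch below substitutes the literal values from this page. It models the arithmetic only. The `FP4_PER_BYTE` factor is an assumption based on MXFP4 elements being 4 bits wide, and `mfma_k_elements` is a name introduced here for illustration.

```python
# Literal values copied from the comptime members on this page.
MFMA_K_LANES = 4
MFMA_LANE_BYTES = 16
MFMA_MN_LANES = 16

MFMA_K_BYTES = MFMA_K_LANES * MFMA_LANE_BYTES   # bytes of K covered per MFMA tile
NUM_THREADS = MFMA_MN_LANES * MFMA_K_LANES      # one thread per (mn_lane, k_lane)

# Assumption: two 4-bit FP4 elements are packed into each byte.
FP4_PER_BYTE = 2
mfma_k_elements = MFMA_K_BYTES * FP4_PER_BYTE

print(MFMA_K_BYTES, NUM_THREADS, mfma_k_elements)  # 64 64 128
```

Both MFMA_K_BYTES and NUM_THREADS come out to 64. The thread count matches the 64-lane wavefront size of CDNA GPUs, and the byte count matches the "multiple of 64" requirement that preshuffle_b_5d places on K_BYTES.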

packed_scale_bytes​

comptime packed_scale_bytes = 4

S_K_BLOCK​

comptime S_K_BLOCK = (Shuffler[E].MFMA_K_LANES * Shuffler[E].S_K_PACK)

S_K_PACK​

comptime S_K_PACK = 2

S_MN_BLOCK​

comptime S_MN_BLOCK = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].S_MN_PACK)

S_MN_PACK​

comptime S_MN_PACK = 2
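The scale block sizes follow from the pack factors in the same way. The Python sketch below substitutes the literal values from this page; `scales_per_block` is a name introduced here for illustration and is not a member of Shuffler.

```python
# Literal values copied from the comptime members on this page.
MFMA_K_LANES = 4
MFMA_MN_LANES = 16
S_K_PACK = 2
S_MN_PACK = 2
packed_scale_bytes = 4

S_K_BLOCK = MFMA_K_LANES * S_K_PACK       # 8 scales along K
S_MN_BLOCK = MFMA_MN_LANES * S_MN_PACK    # 32 rows along MN

# Illustrative only: number of scale entries covered by one block.
scales_per_block = S_K_BLOCK * S_MN_BLOCK

print(S_K_BLOCK, S_MN_BLOCK, scales_per_block)  # 8 32 256
```

S_MN_BLOCK evaluates to 32, the same value that scale_padded_mn rounds MN up to. The product S_K_PACK * S_MN_PACK also equals packed_scale_bytes (4), although this page does not state that the two are defined in terms of each other.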

Methods​

scale_4d_byte_off​

static def scale_4d_byte_off[K_SCALES: Int, packed_mode: Bool = False](mn: Int, k_scale: Int) -> Int

Returns:

Int

scale_padded_mn​

static def scale_padded_mn(MN: Int) -> Int

Padded MN dimension used by the 4D scale layout: MN rounded up to a multiple of 32.

Returns:

Int
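A minimal Python sketch of the rounding that scale_padded_mn describes, assuming "rounded up to 32" means rounding up to the next multiple of 32 (the value of S_MN_BLOCK). This models the documented behavior and is not the Mojo implementation.

```python
S_MN_BLOCK = 32  # MFMA_MN_LANES * S_MN_PACK = 16 * 2


def scale_padded_mn(mn: int) -> int:
    """Round mn up to the next multiple of S_MN_BLOCK (assumed semantics)."""
    return (mn + S_MN_BLOCK - 1) // S_MN_BLOCK * S_MN_BLOCK


for mn in (1, 32, 33, 100):
    print(mn, "->", scale_padded_mn(mn))
# 1 -> 32
# 32 -> 32
# 33 -> 64
# 100 -> 128
```

Values that are already a multiple of 32 are returned unchanged; everything else is bumped to the next block boundary.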

preshuffle_b_5d​

static def preshuffle_b_5d[N: Int, K_BYTES: Int](raw: TileTensor[DType.uint8, linear_idx_type=raw.linear_idx_type, element_size=raw.element_size], dst: TileTensor[DType.uint8, linear_idx_type=dst.linear_idx_type, element_size=dst.element_size], ctx: DeviceContext)

Launch the GPU MXFP4 B 5D preshuffle.

Invoked eagerly from model weight adapters (as a one-shot graph) so that the shuffle runs once at session.load, replacing the NumPy CPU path, which takes on the order of hours. Follows the same origin-handling pattern as block_scales_interleave: it accepts tensors of any origin and casts them to any-origin for the kernel.

Parameters:

  • ​N (Int): Per-expert N (must be a multiple of 16).
  • ​K_BYTES (Int): Per-expert FP4-packed K (must be a multiple of 64).
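The two divisibility constraints can be checked on the host before the preshuffle is launched. The Python sketch below is one way to do that. `check_b_shape` is a hypothetical helper, not part of Shuffler, and the tile count it returns assumes each tile spans MFMA_MN_LANES (16) rows and MFMA_K_BYTES (64) bytes of K.

```python
MFMA_MN_LANES = 16   # required multiple for N
MFMA_K_BYTES = 64    # required multiple for K_BYTES


def check_b_shape(n: int, k_bytes: int) -> int:
    """Hypothetical helper: validate N and K_BYTES, return the assumed tile count."""
    if n <= 0 or n % MFMA_MN_LANES != 0:
        raise ValueError(f"N={n} must be a positive multiple of {MFMA_MN_LANES}")
    if k_bytes <= 0 or k_bytes % MFMA_K_BYTES != 0:
        raise ValueError(
            f"K_BYTES={k_bytes} must be a positive multiple of {MFMA_K_BYTES}"
        )
    # Assumed tiling: 16 rows of N by 64 bytes of K per tile.
    return (n // MFMA_MN_LANES) * (k_bytes // MFMA_K_BYTES)


print(check_b_shape(4096, 2048))  # 8192
```

Because K_BYTES counts FP4-packed bytes, a K_BYTES of 2048 corresponds to 4096 FP4 elements along K if two elements are packed per byte.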

Args:

preshuffle_scale_4d​

static def preshuffle_scale_4d[MN: Int, K_SCALES: Int, SrcLayout: TensorLayout](src: TileTensor[DType.uint8, SrcLayout, MutAnyOrigin], mut dst: HostBuffer[DType.uint8])

preshuffle_grouped_scale_4d_gpu​

static def preshuffle_grouped_scale_4d_gpu[K_SCALES: Int, SfaRawLayout: TensorLayout, SfaPreLayout: TensorLayout, AOffsetsLayout: TensorLayout](sfa_raw: TileTensor[DType.uint8, SfaRawLayout, linear_idx_type=sfa_raw.linear_idx_type, element_size=sfa_raw.element_size], sfa_pre: TileTensor[DType.uint8, SfaPreLayout, linear_idx_type=sfa_pre.linear_idx_type, element_size=sfa_pre.element_size], a_offsets: TileTensor[DType.uint32, AOffsetsLayout, linear_idx_type=a_offsets.linear_idx_type, element_size=a_offsets.element_size], num_active_experts: Int, max_num_tokens_per_expert: Int, total_wg: Int, ctx: DeviceContext)