Mojo struct

Shuffler

struct Shuffler[E: Int]

Host-side MXFP4 preshuffle layouts and helpers for AMD CDNA4.

Parameters​

  • ​E (Int): Number of groups (experts / sort-blocks) the shuffler operates on. Use Shuffler[1] for single-group consumers.

Implemented traits​

AnyType, ImplicitlyDestructible

comptime members​

b_5d_grouped_layout​

comptime b_5d_grouped_layout[N: Int, K_BYTES: Int] = Layout(Coord(Idx[E](), Coord(Idx[Shuffler[E].MFMA_MN_LANES](), Idx[(N // Shuffler[E].MFMA_MN_LANES)]()), Coord(Idx[Shuffler[E].MFMA_LANE_BYTES](), Idx[Shuffler[E].MFMA_K_LANES](), Idx[(K_BYTES // Shuffler[E].MFMA_K_BYTES)]())), Coord(Idx[(N * K_BYTES)](), Coord(Idx[Shuffler[E].B_STRIDE_MN_LANE](), Idx[((K_BYTES // Shuffler[E].MFMA_K_BYTES) * Shuffler[E].B_STRIDE_K0)]()), Coord(Idx[Shuffler[E].B_STRIDE_LANE_BYTES](), Idx[Shuffler[E].B_STRIDE_K_LANE](), Idx[Shuffler[E].B_STRIDE_K0]())))

Parameters​

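  • N (Int): Per-group N dimension of B. Must be a multiple of 16 (MFMA_MN_LANES).
  • K_BYTES (Int): FP4-packed K dimension in bytes (logical K // 2). Must be a multiple of 64 (MFMA_K_BYTES).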

B_STRIDE_K0​

comptime B_STRIDE_K0 = (Shuffler[E].MFMA_K_LANES * Shuffler[E].B_STRIDE_K_LANE)

B_STRIDE_K_LANE​

comptime B_STRIDE_K_LANE = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].MFMA_LANE_BYTES)

B_STRIDE_LANE_BYTES​

comptime B_STRIDE_LANE_BYTES = 1

B_STRIDE_MN_LANE​

comptime B_STRIDE_MN_LANE = Shuffler[E].MFMA_LANE_BYTES

BTileTensor​

comptime BTileTensor[N: Int, K_BYTES: Int] = TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]

Parameters​

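  • N (Int): Per-group N dimension of B. Must be a multiple of 16.
  • K_BYTES (Int): FP4-packed K dimension in bytes (logical K // 2). Must be a multiple of 64.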

MFMA_K_BYTES​

comptime MFMA_K_BYTES = (Shuffler[E].MFMA_K_LANES * Shuffler[E].MFMA_LANE_BYTES)

MFMA_K_LANES​

comptime MFMA_K_LANES = 4

MFMA_LANE_BYTES​

comptime MFMA_LANE_BYTES = 16

MFMA_MN_LANES​

comptime MFMA_MN_LANES = 16

S_K_BLOCK​

comptime S_K_BLOCK = (Shuffler[E].MFMA_K_LANES * Shuffler[E].S_K_PACK)

S_K_PACK​

comptime S_K_PACK = 2

S_MN_BLOCK​

comptime S_MN_BLOCK = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].S_MN_PACK)

S_MN_PACK​

comptime S_MN_PACK = 2

S_STRIDE_K0​

comptime S_STRIDE_K0 = (Shuffler[E].MFMA_K_LANES * Shuffler[E].S_STRIDE_K_LANE)

S_STRIDE_K_LANE​

comptime S_STRIDE_K_LANE = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].S_STRIDE_MN_LANE)

S_STRIDE_K_PACK​

comptime S_STRIDE_K_PACK = 2

S_STRIDE_MN_LANE​

comptime S_STRIDE_MN_LANE = 4

S_STRIDE_MN_PACK​

comptime S_STRIDE_MN_PACK = 1
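
The derived constants in this list all resolve to fixed numbers, since none of them depends on E. As a quick reference, the following sketch re-evaluates them in plain Python (not Mojo, and not part of the Shuffler API); the names and formulas are copied from the members listed on this page.

```python
# Primitive constants, copied from the comptime members on this page.
MFMA_K_LANES = 4
MFMA_LANE_BYTES = 16
MFMA_MN_LANES = 16
S_K_PACK = 2
S_MN_PACK = 2
S_STRIDE_MN_LANE = 4

# Derived constants for the B (weight) byte layout.
MFMA_K_BYTES = MFMA_K_LANES * MFMA_LANE_BYTES        # 64 bytes of K per MFMA tile
B_STRIDE_MN_LANE = MFMA_LANE_BYTES                   # 16
B_STRIDE_K_LANE = MFMA_MN_LANES * MFMA_LANE_BYTES    # 256
B_STRIDE_K0 = MFMA_K_LANES * B_STRIDE_K_LANE         # 1024 bytes per 16 x 64 tile

# Derived constants for the scale layout.
S_K_BLOCK = MFMA_K_LANES * S_K_PACK                  # 8 scales along K per block
S_MN_BLOCK = MFMA_MN_LANES * S_MN_PACK               # 32 rows per block
S_STRIDE_K_LANE = MFMA_MN_LANES * S_STRIDE_MN_LANE   # 64
S_STRIDE_K0 = MFMA_K_LANES * S_STRIDE_K_LANE         # 256 bytes per 32 x 8 block
```

These values line up with the divisibility requirements stated under Methods: N a multiple of 16, K_BYTES a multiple of 64, K_SCALES a multiple of 8, and MN padded to a multiple of 32.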

scale_4d_grouped_layout​

comptime scale_4d_grouped_layout[MN_padded: Int, K_SCALES: Int] = Layout(Coord(Idx[E](), Coord(Idx[Shuffler[E].MFMA_MN_LANES](), Idx[Shuffler[E].S_MN_PACK](), Idx[(MN_padded // Shuffler[E].S_MN_BLOCK)]()), Coord(Idx[Shuffler[E].MFMA_K_LANES](), Idx[Shuffler[E].S_K_PACK](), Idx[(K_SCALES // Shuffler[E].S_K_BLOCK)]())), Coord(Idx[(MN_padded * K_SCALES)](), Coord(Idx[Shuffler[E].S_STRIDE_MN_LANE](), Idx[Shuffler[E].S_STRIDE_MN_PACK](), Idx[((K_SCALES // Shuffler[E].S_K_BLOCK) * Shuffler[E].S_STRIDE_K0)]()), Coord(Idx[Shuffler[E].S_STRIDE_K_LANE](), Idx[Shuffler[E].S_STRIDE_K_PACK](), Idx[Shuffler[E].S_STRIDE_K0]())))

Parameters​

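  • MN_padded (Int): Per-group MN dimension after padding to a multiple of 32 (see scale_padded_mn).
  • K_SCALES (Int): Number of E8M0 scale bytes along K (K // 32). Must be a multiple of 8 (S_K_BLOCK).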

ScaleTileTensor​

comptime ScaleTileTensor[MN: Int, K_SCALES: Int] = TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]

Parameters​

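  • MN (Int): Per-group MN dimension of the scale tensor.
  • K_SCALES (Int): Number of E8M0 scale bytes along K (K // 32). Must be a multiple of 8.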

Methods​

scale_padded_mn​

static scale_padded_mn(MN: Int) -> Int

Padded MN dimension used by the 4D scale layout: MN rounded up to the next multiple of 32 (S_MN_BLOCK).

Returns:

Int
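
The rounding rule can be illustrated with a small Python model (a sketch of the arithmetic only, not the Mojo implementation). The helper `scale_dst_bytes` is a hypothetical name introduced here to show how the padded value feeds into the destination buffer size described for preshuffle_scale_4d.

```python
S_MN_BLOCK = 32  # MFMA_MN_LANES * S_MN_PACK = 16 * 2


def scale_padded_mn(mn: int) -> int:
    """Round mn up to the next multiple of S_MN_BLOCK."""
    return (mn + S_MN_BLOCK - 1) // S_MN_BLOCK * S_MN_BLOCK


def scale_dst_bytes(e: int, mn: int, k_scales: int) -> int:
    """Hypothetical helper: byte size of the flat scale buffer for e groups."""
    return e * scale_padded_mn(mn) * k_scales
```

For example, MN = 40 pads to 64, so two groups with K_SCALES = 16 would need 2 * 64 * 16 = 2048 bytes.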

preshuffle_b_5d​

static preshuffle_b_5d[N: Int, K_BYTES: Int, SrcLayout: TensorLayout](src: TileTensor[DType.uint8, SrcLayout, MutAnyOrigin], mut dst: HostBuffer[DType.uint8]) -> TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]

Preshuffle B from [E, N, K_BYTES] row-major to the 5D byte layout.

src is a 3D (E, N, K_BYTES) row-major tensor; dst is a flat host buffer of E * N * K_BYTES bytes. Returns dst wrapped as a TileTensor with layout Shuffler[E].b_5d_grouped_layout[N, K_BYTES], ready for buffer_load_dwordx4 direct-to-VGPR reads on the kernel side.

K_BYTES is the FP4-packed K dim (logical K // 2). N must be a multiple of 16, K_BYTES a multiple of 64.

Returns:

TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]
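
To make the 5D layout concrete, here is a Python model of the byte permutation implied by the shape and strides of b_5d_grouped_layout. It is a sketch under one assumption that this page does not state: nested coordinates are taken to decompose with the leftmost mode varying fastest (the CuTe-style convention), so n = n_lane + 16 * n_blk and k = lane_byte + 16 * k_lane + 64 * k0. The function names are illustrative and are not part of the Mojo API.

```python
def b_5d_offset(e: int, n: int, k: int, N: int, K_BYTES: int) -> int:
    """Byte offset of logical element (e, n, k) in the preshuffled buffer."""
    assert N % 16 == 0 and K_BYTES % 64 == 0
    n_lane, n_blk = n % 16, n // 16   # shape (MFMA_MN_LANES, N // 16)
    lane_byte = k % 16                # shape MFMA_LANE_BYTES
    k_lane = (k // 16) % 4            # shape MFMA_K_LANES
    k0 = k // 64                      # shape K_BYTES // MFMA_K_BYTES
    return (
        e * N * K_BYTES                    # group stride
        + n_lane * 16                      # B_STRIDE_MN_LANE
        + n_blk * (K_BYTES // 64) * 1024   # one row of 16 x 64 tiles
        + lane_byte * 1                    # B_STRIDE_LANE_BYTES
        + k_lane * 256                     # B_STRIDE_K_LANE
        + k0 * 1024                        # B_STRIDE_K0
    )


def preshuffle_b(src: bytes, E: int, N: int, K_BYTES: int) -> bytearray:
    """Permute a row-major (E, N, K_BYTES) byte buffer into the modeled layout."""
    dst = bytearray(E * N * K_BYTES)
    for e in range(E):
        for n in range(N):
            for k in range(K_BYTES):
                dst[b_5d_offset(e, n, k, N, K_BYTES)] = src[(e * N + n) * K_BYTES + k]
    return dst
```

Under this model each run of 16 consecutive K bytes from one row stays contiguous in dst, which matches the 16-byte width of a buffer_load_dwordx4, and every 16 x 64 tile occupies one contiguous 1024-byte block.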

preshuffle_scale_4d​

static preshuffle_scale_4d[MN: Int, K_SCALES: Int, SrcLayout: TensorLayout](src: TileTensor[DType.uint8, SrcLayout, MutAnyOrigin], mut dst: HostBuffer[DType.uint8]) -> TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]

Preshuffle E8M0 scales from [E, MN, K_SCALES] to the 4D layout.

src is a 3D (E, MN, K_SCALES) row-major tensor; dst is a flat host buffer of E * Shuffler.scale_padded_mn(MN) * K_SCALES bytes. Returns dst wrapped as a TileTensor with the 4D scale layout.

K_SCALES is K // 32 (one E8M0 byte per 32 FP4 elements) and must be a multiple of 8. Per-group MN is padded up to the next multiple of 32; pad rows are zero-filled.

Returns:

TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]
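
The scale permutation can be modeled the same way from the shape and strides of scale_4d_grouped_layout. This Python sketch again assumes that nested coordinates decompose with the leftmost mode varying fastest, so m = m_lane + 16 * m_pack + 32 * m_blk and k = k_lane + 4 * k_pack + 8 * k0. That convention is an assumption, not something this page confirms, and the function names are illustrative only.

```python
def scale_4d_offset(e: int, m: int, k: int, mn_padded: int, k_scales: int) -> int:
    """Byte offset of logical scale (e, m, k) in the preshuffled buffer."""
    assert mn_padded % 32 == 0 and k_scales % 8 == 0
    m_lane, m_pack, m_blk = m % 16, (m // 16) % 2, m // 32
    k_lane, k_pack, k0 = k % 4, (k // 4) % 2, k // 8
    return (
        e * mn_padded * k_scales          # group stride
        + m_lane * 4                      # S_STRIDE_MN_LANE
        + m_pack * 1                      # S_STRIDE_MN_PACK
        + m_blk * (k_scales // 8) * 256   # one row of 32 x 8 blocks
        + k_lane * 64                     # S_STRIDE_K_LANE
        + k_pack * 2                      # S_STRIDE_K_PACK
        + k0 * 256                        # S_STRIDE_K0
    )


def preshuffle_scales(src: bytes, E: int, MN: int, K_SCALES: int) -> bytearray:
    """Permute row-major (E, MN, K_SCALES) scales; pad rows remain zero."""
    mn_padded = (MN + 31) // 32 * 32
    dst = bytearray(E * mn_padded * K_SCALES)  # bytearray starts zero-filled
    for e in range(E):
        for m in range(MN):
            for k in range(K_SCALES):
                off = scale_4d_offset(e, m, k, mn_padded, K_SCALES)
                dst[off] = src[(e * MN + m) * K_SCALES + k]
    return dst
```

In this model the four scales for the 2 x 2 (MN pack, K pack) combinations of a lane land in four adjacent bytes, and each 32 x 8 block of scales occupies 256 contiguous bytes.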