Mojo struct
Shuffler
struct Shuffler[E: Int]
Host-side MXFP4 preshuffle layouts and helpers for AMD CDNA4.
Parameters
- E (Int): Number of groups (experts / sort-blocks) the shuffler operates on. Use Shuffler[1] for single-group consumers.
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
b_5d_grouped_layout
comptime b_5d_grouped_layout[N: Int, K_BYTES: Int] = Layout(Coord(Idx[E](), Coord(Idx[Shuffler[E].MFMA_MN_LANES](), Idx[(N // Shuffler[E].MFMA_MN_LANES)]()), Coord(Idx[Shuffler[E].MFMA_LANE_BYTES](), Idx[Shuffler[E].MFMA_K_LANES](), Idx[(K_BYTES // Shuffler[E].MFMA_K_BYTES)]())), Coord(Idx[(N * K_BYTES)](), Coord(Idx[Shuffler[E].B_STRIDE_MN_LANE](), Idx[((K_BYTES // Shuffler[E].MFMA_K_BYTES) * Shuffler[E].B_STRIDE_K0)]()), Coord(Idx[Shuffler[E].B_STRIDE_LANE_BYTES](), Idx[Shuffler[E].B_STRIDE_K_LANE](), Idx[Shuffler[E].B_STRIDE_K0]())))
Parameters
- N (Int)
- K_BYTES (Int)
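As a rough illustration of how this layout addresses memory, the sketch below models its shape and stride tuples in plain Python (not Mojo) and computes the byte offset of a logical (e, n, k_byte) coordinate. The helper name b_5d_offset is hypothetical, and the way n and k_byte are split into sub-coordinates (first sub-mode varying fastest) is an assumption about the Layout convention, not something this page states.

```python
# Constants copied from the comptime members documented on this page.
MFMA_MN_LANES = 16
MFMA_K_LANES = 4
MFMA_LANE_BYTES = 16
MFMA_K_BYTES = MFMA_K_LANES * MFMA_LANE_BYTES        # 64
B_STRIDE_LANE_BYTES = 1
B_STRIDE_MN_LANE = MFMA_LANE_BYTES                   # 16
B_STRIDE_K_LANE = MFMA_MN_LANES * MFMA_LANE_BYTES    # 256
B_STRIDE_K0 = MFMA_K_LANES * B_STRIDE_K_LANE         # 1024


def b_5d_offset(e: int, n: int, k_byte: int, N: int, K_BYTES: int) -> int:
    """Hypothetical helper: byte offset of (e, n, k_byte) in the 5D layout."""
    # Assumed split: n -> (lane, block); k_byte -> (byte, lane, k0).
    n_lane, n_blk = n % MFMA_MN_LANES, n // MFMA_MN_LANES
    lane_byte = k_byte % MFMA_LANE_BYTES
    k_lane = (k_byte // MFMA_LANE_BYTES) % MFMA_K_LANES
    k0 = k_byte // MFMA_K_BYTES
    # Dot product of the sub-coordinates with the strides from the layout.
    return (
        e * (N * K_BYTES)
        + n_lane * B_STRIDE_MN_LANE
        + n_blk * ((K_BYTES // MFMA_K_BYTES) * B_STRIDE_K0)
        + lane_byte * B_STRIDE_LANE_BYTES
        + k_lane * B_STRIDE_K_LANE
        + k0 * B_STRIDE_K0
    )
```

Under these assumptions, the 16 bytes of one K lane stay contiguous, neighbouring rows within a 16-row block sit 16 bytes apart, and each 16 x 64-byte tile occupies 1024 consecutive bytes.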
B_STRIDE_K0
comptime B_STRIDE_K0 = (Shuffler[E].MFMA_K_LANES * Shuffler[E].B_STRIDE_K_LANE)
B_STRIDE_K_LANE
comptime B_STRIDE_K_LANE = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].MFMA_LANE_BYTES)
B_STRIDE_LANE_BYTES
comptime B_STRIDE_LANE_BYTES = 1
B_STRIDE_MN_LANE
comptime B_STRIDE_MN_LANE = Shuffler[E].MFMA_LANE_BYTES
BTileTensor
comptime BTileTensor[N: Int, K_BYTES: Int] = TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]
Parameters
- N (Int)
- K_BYTES (Int)
MFMA_K_BYTES
comptime MFMA_K_BYTES = (Shuffler[E].MFMA_K_LANES * Shuffler[E].MFMA_LANE_BYTES)
MFMA_K_LANES
comptime MFMA_K_LANES = 4
MFMA_LANE_BYTES
comptime MFMA_LANE_BYTES = 16
MFMA_MN_LANES
comptime MFMA_MN_LANES = 16
S_K_BLOCK
comptime S_K_BLOCK = (Shuffler[E].MFMA_K_LANES * Shuffler[E].S_K_PACK)
S_K_PACK
comptime S_K_PACK = 2
S_MN_BLOCK
comptime S_MN_BLOCK = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].S_MN_PACK)
S_MN_PACK
comptime S_MN_PACK = 2
S_STRIDE_K0
comptime S_STRIDE_K0 = (Shuffler[E].MFMA_K_LANES * Shuffler[E].S_STRIDE_K_LANE)
S_STRIDE_K_LANE
comptime S_STRIDE_K_LANE = (Shuffler[E].MFMA_MN_LANES * Shuffler[E].S_STRIDE_MN_LANE)
S_STRIDE_K_PACK
comptime S_STRIDE_K_PACK = 2
S_STRIDE_MN_LANE
comptime S_STRIDE_MN_LANE = 4
S_STRIDE_MN_PACK
comptime S_STRIDE_MN_PACK = 1
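To make the derived members above concrete, this short Python model (not Mojo) evaluates them from the literal base values. The arithmetic mirrors the comptime expressions on this page; the dictionary name derived is only for illustration.

```python
# Literal base values from this page.
MFMA_K_LANES, MFMA_LANE_BYTES, MFMA_MN_LANES = 4, 16, 16
S_K_PACK, S_MN_PACK = 2, 2
S_STRIDE_MN_LANE = 4

# Derived members, evaluated in dependency order.
derived = {}
derived["MFMA_K_BYTES"] = MFMA_K_LANES * MFMA_LANE_BYTES              # 64
derived["B_STRIDE_K_LANE"] = MFMA_MN_LANES * MFMA_LANE_BYTES          # 256
derived["B_STRIDE_K0"] = MFMA_K_LANES * derived["B_STRIDE_K_LANE"]    # 1024
derived["S_K_BLOCK"] = MFMA_K_LANES * S_K_PACK                        # 8
derived["S_MN_BLOCK"] = MFMA_MN_LANES * S_MN_PACK                     # 32
derived["S_STRIDE_K_LANE"] = MFMA_MN_LANES * S_STRIDE_MN_LANE         # 64
derived["S_STRIDE_K0"] = MFMA_K_LANES * derived["S_STRIDE_K_LANE"]    # 256

for name, value in derived.items():
    print(f"{name} = {value}")
```

The values 64, 8, and 32 match the divisibility and padding requirements stated for K_BYTES, K_SCALES, and MN in the method descriptions on this page.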
scale_4d_grouped_layout
comptime scale_4d_grouped_layout[MN_padded: Int, K_SCALES: Int] = Layout(Coord(Idx[E](), Coord(Idx[Shuffler[E].MFMA_MN_LANES](), Idx[Shuffler[E].S_MN_PACK](), Idx[(MN_padded // Shuffler[E].S_MN_BLOCK)]()), Coord(Idx[Shuffler[E].MFMA_K_LANES](), Idx[Shuffler[E].S_K_PACK](), Idx[(K_SCALES // Shuffler[E].S_K_BLOCK)]())), Coord(Idx[(MN_padded * K_SCALES)](), Coord(Idx[Shuffler[E].S_STRIDE_MN_LANE](), Idx[Shuffler[E].S_STRIDE_MN_PACK](), Idx[((K_SCALES // Shuffler[E].S_K_BLOCK) * Shuffler[E].S_STRIDE_K0)]()), Coord(Idx[Shuffler[E].S_STRIDE_K_LANE](), Idx[Shuffler[E].S_STRIDE_K_PACK](), Idx[Shuffler[E].S_STRIDE_K0]())))
Parameters
- MN_padded (Int)
- K_SCALES (Int)
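The scale layout can be read the same way as any shape and stride pair. The Python sketch below (not Mojo) computes the byte offset of a logical (e, mn, k) scale coordinate. The helper name scale_4d_offset is hypothetical, and the split of mn and k into (lane, pack, block) sub-coordinates with the first sub-mode varying fastest is an assumption about the Layout convention.

```python
# Constants copied from the comptime members documented on this page.
MFMA_MN_LANES, MFMA_K_LANES = 16, 4
S_MN_PACK, S_K_PACK = 2, 2
S_MN_BLOCK = MFMA_MN_LANES * S_MN_PACK                 # 32
S_K_BLOCK = MFMA_K_LANES * S_K_PACK                    # 8
S_STRIDE_MN_LANE, S_STRIDE_MN_PACK, S_STRIDE_K_PACK = 4, 1, 2
S_STRIDE_K_LANE = MFMA_MN_LANES * S_STRIDE_MN_LANE     # 64
S_STRIDE_K0 = MFMA_K_LANES * S_STRIDE_K_LANE           # 256


def scale_4d_offset(e: int, mn: int, k: int, MN_padded: int, K_SCALES: int) -> int:
    """Hypothetical helper: byte offset of (e, mn, k) in the 4D scale layout."""
    # Assumed split: mn -> (lane, pack, block); k -> (lane, pack, k0).
    mn_lane = mn % MFMA_MN_LANES
    mn_pack = (mn // MFMA_MN_LANES) % S_MN_PACK
    mn_blk = mn // S_MN_BLOCK
    k_lane = k % MFMA_K_LANES
    k_pack = (k // MFMA_K_LANES) % S_K_PACK
    k0 = k // S_K_BLOCK
    # Dot product of the sub-coordinates with the strides from the layout.
    return (
        e * (MN_padded * K_SCALES)
        + mn_lane * S_STRIDE_MN_LANE
        + mn_pack * S_STRIDE_MN_PACK
        + mn_blk * ((K_SCALES // S_K_BLOCK) * S_STRIDE_K0)
        + k_lane * S_STRIDE_K_LANE
        + k_pack * S_STRIDE_K_PACK
        + k0 * S_STRIDE_K0
    )
```

Under these assumptions, each 32 x 8 scale tile fills 256 consecutive bytes, and the four bytes at offsets 4 * lane through 4 * lane + 3 hold the 2 x 2 (MN pack, K pack) scales for one MN lane.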
ScaleTileTensor
comptime ScaleTileTensor[MN: Int, K_SCALES: Int] = TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]
Parameters
- MN (Int)
- K_SCALES (Int)
Methods
scale_padded_mn
static scale_padded_mn(MN: Int) -> Int
Returns the padded MN dimension used by the 4D scale layout: MN rounded up to the next multiple of 32 (the value of S_MN_BLOCK).
Returns:
Int
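The rounding described here is ordinary round-up-to-multiple arithmetic. A minimal Python model (not the Mojo method itself; the block argument is added for illustration only) might look like this:

```python
def scale_padded_mn(mn: int, block: int = 32) -> int:
    """Python model of the rounding: smallest multiple of `block` >= mn."""
    return (mn + block - 1) // block * block


print(scale_padded_mn(1))    # 32
print(scale_padded_mn(32))   # 32
print(scale_padded_mn(33))   # 64
```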
preshuffle_b_5d
static preshuffle_b_5d[N: Int, K_BYTES: Int, SrcLayout: TensorLayout](src: TileTensor[DType.uint8, SrcLayout, MutAnyOrigin], mut dst: HostBuffer[DType.uint8]) -> TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]
Preshuffles B from row-major [E, N, K_BYTES] to the 5D byte layout.
src is a 3D (E, N, K_BYTES) row-major tensor; dst is a flat host buffer of E * N * K_BYTES bytes. Returns dst wrapped as a TileTensor with layout Shuffler[E].b_5d_grouped_layout[N, K_BYTES], ready for buffer_load_dwordx4 direct-to-VGPR reads on the kernel side.
K_BYTES is the FP4-packed K dimension (logical K // 2). N must be a multiple of 16 and K_BYTES a multiple of 64.
Returns:
TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]
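The following Python sketch models what such a preshuffle could look like on plain byte buffers. It is not the Mojo implementation: the function name preshuffle_b_5d_model is hypothetical, the strides (16, 256, 1024) are taken from the B_STRIDE_* members on this page, and the coordinate split (first sub-mode fastest) is an assumption.

```python
def preshuffle_b_5d_model(src: bytes, E: int, N: int, K_BYTES: int) -> bytearray:
    """Hypothetical model: copy row-major [E, N, K_BYTES] bytes into the 5D order."""
    # Constraints stated in the description above.
    if N % 16 != 0 or K_BYTES % 64 != 0:
        raise ValueError("N must be a multiple of 16 and K_BYTES a multiple of 64")
    if len(src) != E * N * K_BYTES:
        raise ValueError("src must hold exactly E * N * K_BYTES bytes")

    dst = bytearray(len(src))
    k_blocks = K_BYTES // 64
    for e in range(E):
        for n in range(N):
            for k in range(K_BYTES):
                off = (
                    e * N * K_BYTES
                    + (n % 16) * 16                 # MN lane
                    + (n // 16) * k_blocks * 1024   # 16-row block
                    + (k % 16)                      # byte within a lane
                    + ((k // 16) % 4) * 256         # K lane
                    + (k // 64) * 1024              # 64-byte K block
                )
                dst[off] = src[(e * N + n) * K_BYTES + k]
    return dst


# Logical K = 128 FP4 elements packs into K_BYTES = 128 // 2 = 64.
shuffled = preshuffle_b_5d_model(bytes(16 * 64), E=1, N=16, K_BYTES=64)
```

The output buffer has the same size as the input because the shuffle only permutes bytes; no padding is involved for B.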
preshuffle_scale_4d
static preshuffle_scale_4d[MN: Int, K_SCALES: Int, SrcLayout: TensorLayout](src: TileTensor[DType.uint8, SrcLayout, MutAnyOrigin], mut dst: HostBuffer[DType.uint8]) -> TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]
Preshuffles E8M0 scales from row-major [E, MN, K_SCALES] to the 4D scale layout.
src is a 3D (E, MN, K_SCALES) row-major tensor; dst is a flat host buffer of E * Shuffler.scale_padded_mn(MN) * K_SCALES bytes. Returns dst wrapped as a TileTensor with the 4D scale layout.
K_SCALES is K // 32 (one E8M0 byte per 32 FP4 elements) and must be a multiple of 8. Per-group MN is padded up to a multiple of 32; pad rows are zero-filled.
Returns:
TileTensor[DType.uint8, Layout[*?, *?], MutAnyOrigin]
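As a sketch of the sizing and zero-fill behaviour described here, the Python model below shuffles scales into a padded buffer. It is not the Mojo implementation: the name preshuffle_scale_4d_model is hypothetical, the strides come from the S_STRIDE_* members on this page, and the coordinate split (first sub-mode fastest) is an assumption.

```python
def preshuffle_scale_4d_model(src: bytes, E: int, MN: int, K_SCALES: int) -> bytearray:
    """Hypothetical model: copy row-major [E, MN, K_SCALES] scales into the 4D order."""
    if K_SCALES % 8 != 0:
        raise ValueError("K_SCALES must be a multiple of 8")
    if len(src) != E * MN * K_SCALES:
        raise ValueError("src must hold exactly E * MN * K_SCALES bytes")

    mn_padded = (MN + 31) // 32 * 32
    # bytearray() is zero-initialised, so rows MN..mn_padded-1 stay zero.
    dst = bytearray(E * mn_padded * K_SCALES)
    k_blocks = K_SCALES // 8
    for e in range(E):
        for mn in range(MN):
            for k in range(K_SCALES):
                off = (
                    e * mn_padded * K_SCALES
                    + (mn % 16) * 4                  # MN lane
                    + ((mn // 16) % 2) * 1           # MN pack
                    + (mn // 32) * k_blocks * 256    # 32-row block
                    + (k % 4) * 64                   # K lane
                    + ((k // 4) % 2) * 2             # K pack
                    + (k // 8) * 256                 # 8-scale K block
                )
                dst[off] = src[(e * MN + mn) * K_SCALES + k]
    return dst


# Logical K = 256 gives K_SCALES = 256 // 32 = 8; MN = 20 pads to 32 rows.
out = preshuffle_scale_4d_model(bytes(range(1, 161)), E=1, MN=20, K_SCALES=8)
print(len(out))  # 256
```

Only the per-group MN extent grows; the K_SCALES extent is never padded, which is why the description requires it to be a multiple of 8 up front.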