Mojo struct

PipelineConfig

struct PipelineConfig

Declarative pipeline strategy.

Captures all the knowledge needed to transform a logical loop body into a pipelined schedule: buffer depth, prefetch distance, loop-carried edges, MMA block sizing, and hardware model.

Platform-specific factories (e.g., mi355x_double_buffer() in amd_target.mojo) provide tuned configurations.

Fields​

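  • depth (Int): Pipeline buffer depth (1 = single, 2 = double).
  • prefetch (Int): DRAM-prefetch distance, typically 1.
  • drain_passes (Int): Epilogue drain iteration count.
  • prologue_fill (Int): Extra load iterations in the prologue.
  • loop_carried (LoopCarriedSpec): Ops crossing loop iteration boundaries.
  • block_sizing (BlockSizing): MMA block op targets.
  • frag_order (FragOrder): Fragment ordering within a block.
  • m_mmas (Int): M-dimension MMA tile count.
  • n_mmas (Int): N-dimension MMA tile count.
  • num_partitions (Int): Number of warp groups.
  • mma_serial (Bool): Whether the MMA unit is serial.
  • mma_latency (Int): MMA latency in cycles.
  • vm_per_load_a (Int): vmcnt ops per channel-A global load.
  • vm_per_load_b (Int): vmcnt ops per channel-B global load.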
  • ​lgkm_per_load_a (Int): Kernel-geometry-derived lgkmcnt entries per channel-A frag-load. 0 falls back to ScheduleConfig.lgkm_per_load_a.
  • ​lgkm_per_load_b (Int): Kernel-geometry-derived lgkmcnt entries per channel-B frag-load. 0 falls back to ScheduleConfig.lgkm_per_load_b.
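  • ch0_match_field (Int): Channel-0 register-flow match field.
  • ch1_match_field (Int): Channel-1 register-flow match field.
  • warp_stagger (WarpStaggerRule): Warp-group stagger configuration.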
  • cross_stage_rotation (Bool): True when the schedule intentionally pre-loads the next K-partition's leading-quadrant fragments from the other SMEM stage (4-wave's mini-3/4 register rotation). Relaxes the "fragment loads in half h must use stage h" invariant in program_builder._verify_stage_consistency: same-stage and cross-stage frags coexist by design when this is True. The default, False, keeps the strict check active for ping-pong and other schedules that don't rotate.
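
The relaxation described for cross_stage_rotation can be pictured as a per-fragment predicate. The sketch below is illustrative only: frag_stage_ok is a hypothetical helper, not part of program_builder, and the real _verify_stage_consistency may check more than this. It assumes a double-buffered schedule where half and stage are each 0 or 1.

```mojo
# Hypothetical helper, not the real _verify_stage_consistency.
# Assumes half and stage are 0 or 1 (double-buffered schedule).
fn frag_stage_ok(half: Int, stage: Int, cross_stage_rotation: Bool) -> Bool:
    # Strict rule: a fragment load issued in half h reads SMEM stage h.
    if stage == half:
        return True
    # A cross-stage load is accepted only when rotation is enabled.
    return cross_stage_rotation


def main():
    print(frag_stage_ok(0, 0, False))  # same stage, strict mode: True
    print(frag_stage_ok(0, 1, False))  # cross stage, strict mode: False
    print(frag_stage_ok(0, 1, True))   # cross stage, rotation on: True
```

With the default False, only same-stage loads pass; setting the flag to True lets both kinds coexist, which matches the field description above.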

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

Methods​

__init__​

__init__(out self, *, depth: Int, prefetch: Int, drain_passes: Int, prologue_fill: Int, loop_carried: LoopCarriedSpec, block_sizing: BlockSizing, frag_order: FragOrder, m_mmas: Int, n_mmas: Int, num_partitions: Int, mma_serial: Bool, mma_latency: Int, vm_per_load_a: Int, vm_per_load_b: Int, ch0_match_field: Int, ch1_match_field: Int, warp_stagger: WarpStaggerRule, lgkm_per_load_a: Int = 0, lgkm_per_load_b: Int = 0, cross_stage_rotation: Bool = False)

Constructs a PipelineConfig from individual fields.

lgkm_per_load_a / lgkm_per_load_b are optional kernel-geometry defaults; pass 0 to fall back to ScheduleConfig.lgkm_per_load_*. See the field-level docstrings on PipelineConfig for per-field meanings.

Args:

  • ​depth (Int): Pipeline buffer depth (1 = single, 2 = double).
  • ​prefetch (Int): DRAM-prefetch distance, typically 1.
  • ​drain_passes (Int): Epilogue drain iteration count.
  • ​prologue_fill (Int): Extra load iterations in the prologue.
  • ​loop_carried (LoopCarriedSpec): Ops crossing loop iteration boundaries.
  • ​block_sizing (BlockSizing): MMA block op targets.
  • ​frag_order (FragOrder): Fragment ordering within a block.
  • ​m_mmas (Int): M-dimension MMA tile count.
  • ​n_mmas (Int): N-dimension MMA tile count.
  • ​num_partitions (Int): Number of warp groups.
  • ​mma_serial (Bool): Whether the MMA unit is serial.
  • ​mma_latency (Int): MMA latency in cycles.
  • ​vm_per_load_a (Int): vmcnt ops per channel-A global load.
  • ​vm_per_load_b (Int): vmcnt ops per channel-B global load.
  • ​ch0_match_field (Int): Channel-0 register-flow match field.
  • ​ch1_match_field (Int): Channel-1 register-flow match field.
  • ​warp_stagger (WarpStaggerRule): Warp-group stagger configuration.
  • ​lgkm_per_load_a (Int): lgkmcnt ops per channel-A frag-load (0 = fall back to ScheduleConfig).
  • ​lgkm_per_load_b (Int): lgkmcnt ops per channel-B frag-load (0 = fall back to ScheduleConfig).
  • ​cross_stage_rotation (Bool): Set to True for schedules that intentionally pre-load the next K-partition's leading-quadrant fragments from the cross stage (4-wave's mini-3/4 rotation). Relaxes the strict stage-consistency invariant in _verify_stage_consistency.

mmas_per_partition​

mmas_per_partition(self) -> Int

MMA ops per warp group: m_mmas × n_mmas.

Returns:

Int

globals_per_partition​

globals_per_partition(self) -> Int

Global loads per warp group: m_mmas + n_mmas (A + B tiles).

Returns:

Int

frags_per_partition​

frags_per_partition(self) -> Int

Fragment loads per warp group: m_mmas + n_mmas (A + B frags).

Returns:

Int
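
The three per-partition counts above follow directly from m_mmas and n_mmas. As a quick check of the documented formulas, here is a minimal sketch using free functions; the function signatures are an assumption made for illustration, since the real methods read both values from the PipelineConfig instance.

```mojo
# Hypothetical free functions mirroring the documented formulas.
# The real methods take only `self` and read m_mmas / n_mmas from it.
fn mmas_per_partition(m_mmas: Int, n_mmas: Int) -> Int:
    return m_mmas * n_mmas  # one MMA per (M tile, N tile) pair


fn globals_per_partition(m_mmas: Int, n_mmas: Int) -> Int:
    return m_mmas + n_mmas  # A tiles + B tiles


fn frags_per_partition(m_mmas: Int, n_mmas: Int) -> Int:
    return m_mmas + n_mmas  # A frags + B frags


def main():
    # Tile counts chosen only for illustration.
    var m = 2
    var n = 4
    print(mmas_per_partition(m, n))     # 8
    print(globals_per_partition(m, n))  # 6
    print(frags_per_partition(m, n))    # 6
```

Note that MMA count grows with the product of the tile counts while load counts grow with their sum, so larger tiles raise the compute-to-load ratio.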

ops_per_partition​

ops_per_partition(self) -> Int

Total ops per warp group.

Returns:

Int

total_ops​

total_ops(self) -> Int

Total ops across all warp groups.

Returns:

Int

blocks_per_partition​

blocks_per_partition(self) -> Int

MMA blocks per warp group (one block per MMA).

Returns:

Int

total_blocks​

total_blocks(self) -> Int

Total MMA blocks.

Returns:

Int
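
Because there is one block per MMA, blocks_per_partition equals m_mmas × n_mmas. The sketch below also shows a total across warp groups, but that part is an assumption: the page does not state how total_blocks is computed, and the code simply multiplies by num_partitions. The free-function signatures are likewise hypothetical.

```mojo
# Hypothetical free functions; the real methods take only `self`.
fn blocks_per_partition(m_mmas: Int, n_mmas: Int) -> Int:
    # Documented: one block per MMA, so this equals m_mmas * n_mmas.
    return m_mmas * n_mmas


fn total_blocks(m_mmas: Int, n_mmas: Int, num_partitions: Int) -> Int:
    # Assumption: every warp group contributes the same number of blocks.
    return blocks_per_partition(m_mmas, n_mmas) * num_partitions


def main():
    # Values chosen only for illustration.
    print(blocks_per_partition(2, 2))  # 4
    print(total_blocks(2, 2, 2))       # 8
```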

compute_match_key​

compute_match_key(self, compute_op: OpDesc, channel: Int) -> Int

Extracts the compute field that a fragment on the given channel matches.

For channel 0 (A), returns compute_op.stage (row). For channel 1 (B), returns compute_op.subtile (col).

Returns:

Int
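
The channel-to-field mapping can be sketched as follows. This is a simplified stand-in, not the library's implementation: match_key is a hypothetical name, it takes the stage and subtile values directly instead of an OpDesc, and it treats any non-zero channel as channel B, which is an assumption borrowed from the lgkm_per_channel convention.

```mojo
# Hypothetical stand-in for compute_match_key.
# Takes the two OpDesc fields directly to stay self-contained.
fn match_key(stage: Int, subtile: Int, channel: Int) -> Int:
    # Channel 0 (A) fragments match on the compute op's row.
    if channel == 0:
        return stage
    # Assumption: any other channel is treated as B and matches the column.
    return subtile


def main():
    # A compute op at row 1, column 3 (illustrative values).
    print(match_key(1, 3, 0))  # 1
    print(match_key(1, 3, 1))  # 3
```

An A fragment for row 1 therefore pairs with every compute op whose stage is 1, and a B fragment for column 3 pairs with every compute op whose subtile is 3.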

vm_per_channel​

vm_per_channel(self, channel: Int) -> Int

Returns the vmcnt cost for a global load on the given channel.

Returns:

Int

lgkm_per_channel​

lgkm_per_channel(self, channel: Int) -> Int

Returns the lgkmcnt cost for a fragment load on the given channel.

Reads from lgkm_per_load_a/b set on the config (typically populated from KernelGeometry). Returns 0 if unset; callers should fall back to ScheduleConfig.lgkm_per_load_* for legacy schedules.

Args:

  • ​channel (Int): 0 for channel A, anything else for channel B.

Returns:

Int: lgkmcnt entries per fragment load on channel, or 0 if unset.
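
The channel selection and the caller-side fallback can be sketched together. Both functions below are hypothetical free-function stand-ins: the real lgkm_per_channel is a method on PipelineConfig, and resolve_lgkm only illustrates the fallback that callers are expected to perform, with schedule_default standing in for the matching ScheduleConfig.lgkm_per_load_* value.

```mojo
# Hypothetical stand-in for the PipelineConfig method.
fn lgkm_per_channel(lgkm_per_load_a: Int, lgkm_per_load_b: Int, channel: Int) -> Int:
    # 0 selects channel A; anything else selects channel B.
    if channel == 0:
        return lgkm_per_load_a
    return lgkm_per_load_b


# Hypothetical caller-side helper illustrating the documented fallback.
fn resolve_lgkm(config_value: Int, schedule_default: Int) -> Int:
    # 0 means "unset": use the ScheduleConfig value instead.
    if config_value == 0:
        return schedule_default
    return config_value


def main():
    # Channel A is set from kernel geometry (4); channel B is unset (0).
    var a = lgkm_per_channel(4, 0, 0)
    var b = lgkm_per_channel(4, 0, 1)
    # Illustrative ScheduleConfig default of 2 for both channels.
    print(resolve_lgkm(a, 2))  # 4
    print(resolve_lgkm(b, 2))  # 2
```

Keeping the fallback in the caller lets legacy schedules that never populate the geometry-derived fields continue to use their ScheduleConfig values unchanged.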

total_edges​

total_edges(self) -> Int

Total dependency edges for a double-buffered pipeline.

Four phases of edges connect ops within and across iterations:

  • reg_flow: fragment_load → compute (register FLOW). Each half has 2 channels (A, B), each channel's frag feeds m*n compute ops.
  • accum: compute → compute (accumulator forwarding). m*n accumulator tiles forwarded between halves, twice (half0→half1 at d=0, half1→half0 at d=1).
  • lds_flow: global_load → fragment_load (LDS FLOW). Each half has g global loads, each feeds one frag per buffer stage (×2 for double-buffering).
  • lds_anti: fragment_load → global_load (LDS ANTI). Prevents a prefetch write from overwriting data a frag still needs. 2*g frags total minus 1 (last frag has no successor load), doubled for both channels.

Returns:

Int