Skip to main content

Mojo struct

MLAKVProducerPipeline

@register_passable(trivial) struct MLAKVProducerPipeline[dtype: DType, config: FA4Config]

Fields

  • kv_pipeline (StagedPipeline[config.num_kv_stages, config.num_qk_stages]):
  • smem (MLAKVProducerPipeline[dtype, config].SMemType):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

k_bytes

comptime k_bytes = (MLAKVProducerPipeline[dtype, config].k_elements * size_of[dtype]())

k_elements

comptime k_elements = MLAKVProducerPipeline[dtype, config].k_tma_layout.size()

k_layout

comptime k_layout = tile_layout_k_major[dtype, config.BN, 128, config.swizzle_mode]()

k_rope_layout

comptime k_rope_layout = tile_layout_k_major[dtype, config.BN, 64, config.swizzle_mode]()

k_tma_layout

comptime k_tma_layout = tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]()

KPairType

comptime KPairType = TMADestination[dtype, MLAKVProducerPipeline[dtype, config].k_tma_layout]

KType

comptime KType = LayoutTensor[dtype, MLAKVProducerPipeline[dtype, config].k_tma_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]

SMemType

comptime SMemType = UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

v_bytes

comptime v_bytes = (MLAKVProducerPipeline[dtype, config].v_elements * size_of[dtype]())

v_elements

comptime v_elements = MLAKVProducerPipeline[dtype, config].v_tma_layout.size()

v_tma_layout

comptime v_tma_layout = tile_layout_mn_major[dtype, 128, config.BK1, config.swizzle_mode]()

VPairType

comptime VPairType = TMADestination[dtype, MLAKVProducerPipeline[dtype, config].v_tma_layout]

VType

comptime VType = LayoutTensor[dtype, MLAKVProducerPipeline[dtype, config].v_tma_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]

Methods

__init__

__init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

__init__(kv_pipeline: StagedPipeline[config.num_kv_stages, config.num_qk_stages], smem: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self

get_kv_smem

get_kv_smem[*, qk_stage: Int](self) -> MLAKVProducerPipeline[dtype, config].SMemType

Returns:

MLAKVProducerPipeline[dtype, config].SMemType

get_k

get_k[*, qk_stage: Int, expect: Bool = True](self) -> MLAKVProducerPipeline[dtype, config].KPairType

Returns:

MLAKVProducerPipeline[dtype, config].KPairType

get_v

get_v[*, qk_stage: Int](self) -> MLAKVProducerPipeline[dtype, config].VPairType

Returns:

MLAKVProducerPipeline[dtype, config].VPairType

acquire_kv

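acquire_kv[*, qk_stage: Int = (config.num_qk_stages - 1)](self)

(The default was rendered as `(config - 1)`, which is not a valid integer expression for an `FA4Config`. Presumably it is the last QK stage, `config.num_qk_stages - 1`; confirm against the source.)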

commit_kv_step

commit_kv_step(mut self)

Step the KV pipeline. This does not perform the commit on the mbars; that should be handled by `tma_op.async_copy`.

Was this page helpful?