Mojo struct
MLAKVProducerPipeline
@register_passable(trivial)
struct MLAKVProducerPipeline[k_nope_dtype: DType, k_rope_dtype: DType, config: FA4Config]
Fields
- kv_pipeline (StagedPipeline[config.num_kv_stages, config.num_qk_stages])
- smem (MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].SMemType)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
k_bytes
comptime k_bytes = (MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_nope_bytes + MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_rope_bytes)
k_elements
comptime k_elements = (MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_nope_elements + MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_rope_elements)
k_nope_bytes
comptime k_nope_bytes = (MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_nope_elements * size_of[k_nope_dtype]())
k_nope_elements
comptime k_nope_elements = MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_nope_tma_layout.size()
k_nope_tma_layout
comptime k_nope_tma_layout = tile_layout_k_major[k_nope_dtype, config.BN, 128, config.swizzle_mode]()
k_rope_bytes
comptime k_rope_bytes = (MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_rope_elements * size_of[k_rope_dtype]())
k_rope_elements
comptime k_rope_elements = MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_rope_tma_layout.size()
k_rope_swizzle_mode
comptime k_rope_swizzle_mode = TensorMapSwizzle(2) if (k_rope_dtype == DType.float8_e4m3fn)._mlir_value else TensorMapSwizzle.SWIZZLE_128B
k_rope_tma_layout
comptime k_rope_tma_layout = tile_layout_k_major[k_rope_dtype, config.BN, 64, MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_rope_swizzle_mode]()
k_tma_layout
comptime k_tma_layout = tile_layout_k_major[k_nope_dtype, config.BN, config.BK0, config.swizzle_mode]()
KPairType
comptime KPairType = TMADestination[k_nope_dtype, MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_tma_layout]
KType
comptime KType = LayoutTensor[k_nope_dtype, MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].k_tma_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]
SMemType
comptime SMemType = UnsafePointer[Scalar[k_nope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
v_bytes
comptime v_bytes = (MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].v_elements * size_of[k_nope_dtype]())
v_elements
comptime v_elements = MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].v_tma_layout.size()
v_tma_layout
comptime v_tma_layout = tile_layout_mn_major[k_nope_dtype, 128, config.BK1, config.swizzle_mode]()
VPairType
comptime VPairType = TMADestination[k_nope_dtype, MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].v_tma_layout]
VType
comptime VType = LayoutTensor[k_nope_dtype, MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].v_tma_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]
Methods
__init__
__init__(mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], smem: UnsafePointer[Scalar[k_nope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
__init__(kv_pipeline: StagedPipeline[config.num_kv_stages, config.num_qk_stages], smem: UnsafePointer[Scalar[k_nope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]) -> Self
get_kv_smem
get_kv_smem[*, qk_stage: Int](self) -> MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].SMemType
Returns:
MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].SMemType
get_k
get_k[*, qk_stage: Int, expect: Bool = True](self) -> MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].KPairType
Returns:
MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].KPairType
get_v
get_v[*, qk_stage: Int](self) -> MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].VPairType
Returns:
MLAKVProducerPipeline[k_nope_dtype, k_rope_dtype, config].VPairType
acquire_kv
acquire_kv[*, qk_stage: Int = (config - 1)](self)
commit_kv_step
commit_kv_step(mut self)
Step the kv pipeline. This does not perform the commit on the mbars; that should be handled by the tma_op.async_copy call.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!