Mojo module
attention_utils
Shared SM100 attention primitives used by both MHA and MLA kernels.
This module contains generic SM100 (Blackwell) GPU primitives including:
- TMEM access helpers (TMemTile, STMatrixLayout)
- Pipeline synchronization (StagedPipeline, RolePipeline, etc.)
- FTZ arithmetic (add_ftz, sub_ftz, mul_ftz, etc.)
- Barrier helpers (FA4MiscMBars)
- MMA building blocks (bulk_mma, SM100TensorAccumulatorSS/TS)
- Masking utilities (apply_mask, apply_oob_mask)
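The FTZ helpers in the list above (add_ftz, sub_ftz, mul_ftz, etc.) are named after flush-to-zero arithmetic. In flush-to-zero mode, subnormal float32 inputs and results are replaced by a zero of the same sign, which is what PTX's .ftz instruction modifier does. The Python model below sketches that behavior. It assumes the Mojo helpers follow PTX .ftz semantics, which this page does not state. The names to_f32, flush and _ftz_op are illustrative only, and the Python add_ftz/sub_ftz/mul_ftz are stand-ins rather than the module's implementation.

```python
import math
import operator
import struct

# Smallest positive normal binary32 value (2**-126); anything smaller but nonzero is subnormal.
FLT_MIN = 2.0 ** -126


def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def flush(x: float) -> float:
    """Replace a binary32 subnormal with a zero of the same sign."""
    if x != 0.0 and abs(x) < FLT_MIN:
        return math.copysign(0.0, x)
    return x


def _ftz_op(op, a: float, b: float) -> float:
    # Flush subnormal inputs, round the exact result to float32,
    # then flush a subnormal result (a model of PTX `.ftz` behavior).
    a, b = flush(to_f32(a)), flush(to_f32(b))
    return flush(to_f32(op(a, b)))


# Python stand-ins named after the module's helpers (hypothetical, not their code).
def add_ftz(a: float, b: float) -> float:
    return _ftz_op(operator.add, a, b)


def sub_ftz(a: float, b: float) -> float:
    return _ftz_op(operator.sub, a, b)


def mul_ftz(a: float, b: float) -> float:
    return _ftz_op(operator.mul, a, b)
```

With ordinary float32 rounding, the difference of two nearby values just above FLT_MIN survives as a subnormal. Under the FTZ model it becomes zero. Hardware typically uses this trade-off to avoid slow subnormal handling.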
comptime values
ConsumerPipeline
comptime ConsumerPipeline = RolePipeline[?, False, ?, ?, ?]
KConsumerPipeline
comptime KConsumerPipeline = TMAConsumerPipeline[?, ?]
KPipeline
comptime KPipeline = StagedPipeline[?, ?]
KProducerPipeline
comptime KProducerPipeline = TMAProducerPipeline[?, ?]
KVPipeline
comptime KVPipeline = StagedPipeline[?, ?]
LocalLT
comptime LocalLT[dtype: DType, layout: Layout, element_layout: Layout = Layout(IntTuple(1), IntTuple(1))] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]
Parameters
- dtype (DType)
- layout (Layout)
- element_layout (Layout)
LocalTensor
comptime LocalTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Parameters
- dtype (DType)
- layout (Layout[shape_types, stride_types])
MBarType
comptime MBarType = UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]
ProducerPipeline
comptime ProducerPipeline = RolePipeline[?, producer_sub_stages=?, consumer_sub_stages=?, cta_group=?]
SharedMemPointer
comptime SharedMemPointer[type: AnyType] = UnsafePointer[type, MutAnyOrigin, address_space=AddressSpace.SHARED]
Parameters
- type (AnyType)
SharedMemTensor
comptime SharedMemTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutExternalOrigin, address_space=AddressSpace.SHARED]
Parameters
- dtype (DType)
- layout (Layout[shape_types, stride_types])
VConsumerPipeline
comptime VConsumerPipeline = TMAConsumerPipeline[?, ?, False]
VPipeline
comptime VPipeline = StagedPipeline[?]
VProducerPipeline
comptime VProducerPipeline = TMAProducerPipeline[?, ?, False]
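MBarType points to a SharedMemBarrier in shared memory. The module's expect_bytes_pred function issues mbarrier.arrive.expect_tx on such a barrier. The PTX semantics behind that instruction work as follows. An arrive with expect_tx raises the barrier's transaction (byte) count and then counts one arrival. TMA copies decrement the byte count as data lands. The current phase completes only when both the pending arrival count and the byte count reach zero. The single-threaded Python model below sketches this under those assumptions. TxBarrierModel and expect_bytes_pred_model are hypothetical names. The real expect_bytes_pred signature is not shown on this page.

```python
class TxBarrierModel:
    """Toy model of an mbarrier with transaction counting (not the Mojo API)."""

    def __init__(self, arrivals: int = 1):
        self.arrivals = arrivals   # arrivals required per phase
        self.pending = arrivals    # arrivals still missing in the current phase
        self.tx_count = 0          # bytes still expected in the current phase
        self.phase = 0             # parity of the phase currently in progress

    def _maybe_complete(self) -> None:
        # A phase completes only when no arrivals and no bytes are outstanding.
        if self.pending == 0 and self.tx_count == 0:
            self.phase ^= 1
            self.pending = self.arrivals

    def arrive_expect_tx(self, nbytes: int) -> None:
        # expect-tx first (raise the byte count), then the arrive itself.
        self.tx_count += nbytes
        self.pending -= 1
        self._maybe_complete()

    def complete_tx(self, nbytes: int) -> None:
        # Models the copy engine reporting that nbytes reached shared memory.
        # The count may go negative if data lands before expect_tx is issued.
        self.tx_count -= nbytes
        self._maybe_complete()


def expect_bytes_pred_model(bar: TxBarrierModel, nbytes: int, pred: int) -> None:
    """Hypothetical stand-in: issue arrive.expect_tx only when pred != 0."""
    if pred != 0:
        bar.arrive_expect_tx(nbytes)
```

Predication makes it possible to call the helper uniformly across threads while only the threads with pred != 0 (for example, one elected thread) actually arrive and register the expected byte count.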
Structs
- FA4MiscMBars: Manages all mbarrier resources for FA4.
- MBarPipeline
- RolePipeline: Unified producer/consumer pipeline for barrier synchronization.
- SM100TensorAccumulatorSS
- SM100TensorAccumulatorTS
- StagedPipeline: Unified pipeline for K, V, and KV tile barrier management.
- STMatrixLayout: Layout for using st_matrix to write the final accumulator to smem.
- STMatrixOffsets
- TMAConsumerPipeline: Unified consumer pipeline for K and V TMA consumption.
- TMADestination: Pairs a shared memory TileTensor with a barrier for TMA operations.
- TMAProducerPipeline: Unified producer pipeline for K and V TMA loading.
- TMemTile
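This page does not document how StagedPipeline and RolePipeline work internally. A common technique for multi-stage mbarrier pipelines on NVIDIA GPUs gives each side a stage index plus a phase (parity) bit. It also pairs a "full" barrier (tile loaded) with an "empty" barrier (slot released) for every stage. The single-threaded Python model below sketches that general technique. StageCursor, MBarrierModel and simulate are hypothetical names. The parity convention, where the producer first waits on the opposite parity so its first lap never blocks, is borrowed from common practice and is an assumption, not this module's confirmed behavior.

```python
from dataclasses import dataclass


@dataclass
class StageCursor:
    """Hypothetical ring-buffer cursor: current stage slot plus a parity bit."""
    num_stages: int
    index: int = 0
    phase: int = 0

    def step(self) -> None:
        # Move to the next slot; flip the parity each time the ring wraps so a
        # waiter can tell this lap's barrier completion from the previous one.
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


class MBarrierModel:
    """Toy mbarrier: flips its phase parity after the expected arrivals."""

    def __init__(self, arrivals: int = 1):
        self.expected = arrivals
        self.pending = arrivals
        self.phase = 0  # parity of the phase currently in progress

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            self.phase ^= 1
            self.pending = self.expected

    def try_wait(self, parity: int) -> bool:
        # The phase with this parity has completed iff it is not the current one.
        return self.phase != parity


def simulate(num_stages: int, num_tiles: int) -> list[str]:
    full = [MBarrierModel() for _ in range(num_stages)]   # producer -> consumer
    empty = [MBarrierModel() for _ in range(num_stages)]  # consumer -> producer
    prod, cons = StageCursor(num_stages), StageCursor(num_stages)
    produced = consumed = 0
    log: list[str] = []
    while consumed < num_tiles:
        progress = False
        # Producer waits on the opposite parity of its cursor: on the first lap
        # (phase 0) it waits for parity 1, which counts as already complete.
        while produced < num_tiles and empty[prod.index].try_wait(prod.phase ^ 1):
            log.append(f"load tile {produced} -> stage {prod.index}")
            full[prod.index].arrive()
            prod.step()
            produced += 1
            progress = True
        # Consumer waits until the slot it expects has been filled this lap.
        if full[cons.index].try_wait(cons.phase):
            log.append(f"use tile {consumed} <- stage {cons.index}")
            empty[cons.index].arrive()
            cons.step()
            consumed += 1
            progress = True
        if not progress:
            raise RuntimeError("pipeline deadlocked")
    return log
```

With two stages, the producer runs at most two tiles ahead. It can only reuse stage 0 for tile 2 after the consumer has released that slot by arriving on its empty barrier.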
Functions
- add_ftz
- add_ftz_rm
- apply_mask
- apply_oob_mask
- break_into_powers_of_two
- build_mma_ss
- build_mma_ts
- bulk_mma
- cumulative_power_of_two
- elect
- elect_mma_arrive: Arrive at the mbar pointer for the MMA instruction.
- exp2_emulation
- expect_bytes_pred: Issue mbarrier.arrive.expect_tx.shared::cta.b64 predicated on pred != 0.
- extract_power_of_two
- fma_ftz
- intrin
- intrin_ftz
- intrin_ftz_x2
- llvm_opaque_tid
- max_ftz
- maximum
- mul_ftz
- peel_mask: Determine which mask strategy applies to the peeled first iteration.
- sub_ftz
- sum
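This page gives no descriptions for apply_mask and apply_oob_mask, and peel_mask is described only as choosing a mask strategy for the peeled first iteration. As a hedged illustration of what such utilities commonly do in attention kernels (an assumption, not their documented behavior), the Python sketch below does two things. First, it sets causal and out-of-bounds scores in a tile to -inf. Second, it classifies a whole tile as needing no mask, a partial mask, or being fully masked. A kernel can use such a classification to mask only the iterations that need it. mask_tile and tile_mask_kind are hypothetical names. The actual strategies peel_mask distinguishes are not documented here.

```python
NEG_INF = float("-inf")


def mask_tile(scores, q_start, k_start, seq_len_k, causal=True):
    """Return a masked copy of a score tile (hypothetical helper).

    scores[i][j] is the logit of query row q_start + i against key k_start + j.
    """
    out = []
    for i, row in enumerate(scores):
        q = q_start + i
        masked_row = []
        for j, s in enumerate(row):
            k = k_start + j
            out_of_bounds = k >= seq_len_k   # key index past the real sequence
            in_future = causal and k > q     # causal: a query never sees later keys
            masked_row.append(NEG_INF if (out_of_bounds or in_future) else s)
        out.append(masked_row)
    return out


def tile_mask_kind(q_start, q_rows, k_start, k_cols, seq_len_k, causal=True):
    """Classify a whole tile so per-element masking can be skipped when possible.

    'full'    every element is masked (the tile contributes nothing),
    'none'    no element is masked (no per-element checks needed),
    'partial' per-element masking is required.
    """
    q_first, q_last = q_start, q_start + q_rows - 1
    k_first, k_last = k_start, k_start + k_cols - 1
    if k_first >= seq_len_k or (causal and k_first > q_last):
        return "full"
    if k_last < seq_len_k and (not causal or k_last <= q_first):
        return "none"
    return "partial"
```

Using -inf rather than a large negative constant makes exp(score - row_max) exactly zero for masked keys. This holds as long as each row keeps at least one unmasked entry. A fully masked row would produce -inf - (-inf), which is NaN.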