Mojo module
attention_utils
Shared SM100 attention primitives used by both MHA and MLA kernels.
This module contains generic SM100 (Blackwell) GPU primitives including:
- TMEM access helpers (TMemTile, STMatrixLayout)
- Pipeline synchronization (StagedPipeline, RolePipeline, etc.)
- FTZ arithmetic (add_ftz, sub_ftz, mul_ftz, etc.)
- Barrier helpers (FA4MiscMBars)
- MMA building blocks (bulk_mma, SM100TensorAccumulator)
- Masking utilities (apply_mask, apply_oob_mask)
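The FTZ helpers in the list above are not described further on this page. As a rough illustration of what flush-to-zero (FTZ) arithmetic means in general, the Python sketch below emulates a float32 addition in which subnormal inputs and results are replaced by a zero of the same sign. The names `to_f32`, `flush`, and `add_ftz_sketch` are hypothetical and are not part of this module; the sketch does not show how `add_ftz` is implemented.

```python
import math
import struct

FP32_MIN_NORMAL = 2.0 ** -126  # smallest positive normal float32 value


def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def flush(x: float) -> float:
    """Replace a subnormal float32 value with a zero of the same sign."""
    if x != 0.0 and abs(x) < FP32_MIN_NORMAL:
        return math.copysign(0.0, x)
    return x


def add_ftz_sketch(a: float, b: float) -> float:
    """Hypothetical FTZ add: flush subnormal inputs, add, flush the result."""
    a = flush(to_f32(a))
    b = flush(to_f32(b))
    return flush(to_f32(a + b))
```

With these definitions, adding the smallest normal value to `-1.5` times that value would give a subnormal result under IEEE 754 gradual underflow, but the sketch returns negative zero instead.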
comptime values
ConsumerPipeline
comptime ConsumerPipeline = RolePipeline[_, False, _, _, _]
FP32_EXP_BIAS
comptime FP32_EXP_BIAS = 127
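The value 127 matches the exponent bias of the IEEE 754 binary32 (float32) format. The page does not say how the module uses this constant, so the Python sketch below only illustrates the bias itself: the 8-bit exponent field of a normal float32 stores the true exponent plus 127, above a 23-bit fraction. The helper names are hypothetical.

```python
import struct

FP32_EXP_BIAS = 127      # same value as the constant documented above
FP32_MANTISSA_BITS = 23  # width of the binary32 fraction field


def f32_bits(x: float) -> int:
    """Return the 32-bit pattern of x rounded to float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits: int) -> float:
    """Reinterpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def unbiased_exponent(x: float) -> int:
    """Exponent of a normal float32 value with the bias removed."""
    return ((f32_bits(x) >> FP32_MANTISSA_BITS) & 0xFF) - FP32_EXP_BIAS


def pow2(n: int) -> float:
    """Build 2**n by writing the biased exponent field directly."""
    if not -126 <= n <= 127:
        raise ValueError("n is outside the normal float32 exponent range")
    return bits_to_f32((n + FP32_EXP_BIAS) << FP32_MANTISSA_BITS)
```

Writing the biased exponent into the bit pattern is a common way to produce exact powers of two without calling a floating-point exponential.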
KConsumerPipeline
comptime KConsumerPipeline = TMAConsumerPipeline[_, _]
KPipeline
comptime KPipeline = StagedPipeline[_, _]
KProducerPipeline
comptime KProducerPipeline = TMAProducerPipeline[_, _]
KVPipeline
comptime KVPipeline = StagedPipeline[_, _]
LocalLT
comptime LocalLT[dtype: DType, layout: Layout, element_layout: Layout = Layout(IntTuple(Int(1)), IntTuple(Int(1)))] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]
Parameters
- dtype (DType):
- layout (Layout):
- element_layout (Layout):
LocalTensor
comptime LocalTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
Parameters
- dtype (DType):
- layout (Layout[shape_types, stride_types]):
MBarType
comptime MBarType = UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]
ProducerPipeline
comptime ProducerPipeline = RolePipeline[_, producer_sub_stages=_, consumer_sub_stages=_, cta_group=_]
SharedMemPointer
comptime SharedMemPointer[type: AnyType] = UnsafePointer[type, MutAnyOrigin, address_space=AddressSpace.SHARED]
Parameters
- type (AnyType):
SharedMemTensor
comptime SharedMemTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutUntrackedOrigin, address_space=AddressSpace.SHARED]
Parameters
- dtype (DType):
- layout (Layout[shape_types, stride_types]):
VConsumerPipeline
comptime VConsumerPipeline = TMAConsumerPipeline[_, _, False]
VPipeline
comptime VPipeline = StagedPipeline[_]
VProducerPipeline
comptime VProducerPipeline = TMAProducerPipeline[_, _, False]
Structs
- FA4MiscMBars: Manages all mbarrier resources for FA4.
- MBarPipeline:
- RolePipeline: Unified producer/consumer pipeline for barrier synchronization.
- SM100TensorAccumulator:
- StagedPipeline: Unified pipeline for K, V, and KV tile barrier management.
- STMatrixLayout: Layout for using st_matrix for writing the final accumulator to smem.
- STMatrixOffsets:
- TMAConsumerPipeline: Unified consumer pipeline for K and V TMA consumption.
- TMADestination: Pairs a shared memory TileTensor with a barrier for TMA operations.
- TMAProducerPipeline: Unified producer pipeline for K and V TMA loading.
- TMemTile:
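The pipeline structs above manage barriers across several buffer stages, but this page does not document their fields or methods. A common pattern in staged GPU pipelines, sketched below in Python, is to track a stage index together with a one-bit phase that flips each time the index wraps, so that a waiter can tell the current pass over a barrier from the previous one. The `PipelineState` class and `trace` helper are hypothetical and may not match how `StagedPipeline` or `RolePipeline` are written.

```python
from dataclasses import dataclass


@dataclass
class PipelineState:
    """Hypothetical stage/phase cursor over a ring of num_stages buffers."""

    num_stages: int
    index: int = 0  # which buffer stage is current
    phase: int = 0  # parity bit, flipped on every wrap-around

    def advance(self) -> None:
        """Move to the next stage, flipping the phase when the ring wraps."""
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


def trace(num_stages: int, steps: int) -> list:
    """Record the (index, phase) pair seen at each of `steps` iterations."""
    state = PipelineState(num_stages)
    seen = []
    for _ in range(steps):
        seen.append((state.index, state.phase))
        state.advance()
    return seen
```

In this sketch a producer and a consumer would each keep their own cursor and advance it once per tile, which keeps both sides in agreement about which stage and which pass they are synchronizing on.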
Functions
- add_ftz:
- add_ftz_rm:
- apply_mask:
- apply_oob_mask:
- break_into_powers_of_two:
- bulk_mma:
- bulk_mma_partial:
- bulk_mma_ss_partial:
- bulk_mma_ws:
- bulk_mma_ws_partial:
- bulk_mma_ws_ts:
- bulk_mma_ws_ts_partial:
- cumulative_power_of_two:
- elect_mma_arrive: Arrive at the mbar pointer for the MMA instruction.
- exp2_emulation:
- expect_bytes_pred: Issue mbarrier.arrive.expect_tx.shared::cta.b64 predicated on pred != 0.
- extract_power_of_two:
- fma_ftz:
- intrin:
- intrin_ftz:
- intrin_ftz_x2:
- llvm_opaque_tid:
- max_ftz:
- maximum:
- mul_ftz:
- peel_mask: Determine which mask strategy applies to the peeled first iteration.
- sub_ftz:
- sum: