Mojo module
sm100_attention_utils
Shared SM100 attention primitives used by both MHA and MLA kernels.
This module contains generic SM100 (Blackwell) GPU primitives including:
- TMEM access helpers (TMemTile, STMatrixLayout)
- Pipeline synchronization (StagedPipeline, RolePipeline, etc.)
- FTZ arithmetic (add_ftz, sub_ftz, mul_ftz, etc.)
- Barrier helpers (FA4MiscMBars)
- MMA building blocks (bulk_mma, SM100TensorAccumulatorSS/TS)
- Masking utilities (apply_mask, apply_oob_mask)
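The FTZ helpers named above are not described further on this page. As a rough reference model, assuming FTZ has its usual meaning of flush-to-zero (subnormal float32 inputs and results are replaced by a signed zero), the behavior can be sketched in plain Python. The names `to_f32`, `flush`, `add_ftz_ref`, `sub_ftz_ref`, and `mul_ftz_ref` are hypothetical and are not part of this module; the exact rounding and flushing order on SM100 hardware may differ.

```python
import math
import struct

# Smallest positive normal float32 value; anything smaller in magnitude
# (and nonzero) is subnormal.
F32_MIN_NORMAL = 2.0 ** -126


def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def flush(x: float) -> float:
    """Replace a subnormal float32 value with a zero of the same sign."""
    x = to_f32(x)
    if x != 0.0 and abs(x) < F32_MIN_NORMAL:
        return math.copysign(0.0, x)
    return x


# Hypothetical reference versions: flush the inputs, compute, flush the result.
def add_ftz_ref(a: float, b: float) -> float:
    return flush(flush(a) + flush(b))


def sub_ftz_ref(a: float, b: float) -> float:
    return flush(flush(a) - flush(b))


def mul_ftz_ref(a: float, b: float) -> float:
    return flush(flush(a) * flush(b))
```

In this model, `mul_ftz_ref(2.0**-70, 2.0**-70)` returns zero because the exact product `2.0**-140` is subnormal in float32, whereas plain float32 rounding would keep it.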
comptime values
ConsumerPipeline
comptime ConsumerPipeline = RolePipeline[?, False, ?, ?]
KConsumerPipeline
comptime KConsumerPipeline = TMAConsumerPipeline[?, ?]
KPipeline
comptime KPipeline = StagedPipeline[?, ?]
KProducerPipeline
comptime KProducerPipeline = TMAProducerPipeline[?, ?]
KVPipeline
comptime KVPipeline = StagedPipeline[?, ?]
LocalLT
comptime LocalLT[dtype: DType, layout: Layout, element_layout: Layout = Layout(IntTuple(1), IntTuple(1))] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]
Parameters
- dtype (DType)
- layout (Layout)
- element_layout (Layout)
LocalTensor
comptime LocalTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Parameters
- dtype (DType)
- layout (Layout[shape_types, stride_types])
MBarType
comptime MBarType = UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]
ProducerPipeline
comptime ProducerPipeline = RolePipeline[?, producer_sub_stages=?, consumer_sub_stages=?]
SharedMemLT
comptime SharedMemLT[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]
Parameters
- dtype (DType)
- layout (Layout)
SharedMemPointer
comptime SharedMemPointer[type: AnyType] = UnsafePointer[type, MutAnyOrigin, address_space=AddressSpace.SHARED]
Parameters
- type (AnyType)
SharedMemTensor
comptime SharedMemTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutExternalOrigin, address_space=AddressSpace.SHARED]
Parameters
- dtype (DType)
- layout (Layout[shape_types, stride_types])
VConsumerPipeline
comptime VConsumerPipeline = TMAConsumerPipeline[?, ?, False]
VPipeline
comptime VPipeline = StagedPipeline[?]
VProducerPipeline
comptime VProducerPipeline = TMAProducerPipeline[?, ?, False]
Structs
- FA4MiscMBars: Manages all mbarrier resources for FA4.
- MBarPipeline
- RolePipeline: Unified producer/consumer pipeline for barrier synchronization.
- SM100TensorAccumulatorSS
- SM100TensorAccumulatorTS
- StagedPipeline: Unified pipeline for K, V, and KV tile barrier management.
- STMatrixLayout: Layout for using st_matrix for writing the final accumulator to smem.
- STMatrixOffsets
- TMAConsumerPipeline: Unified consumer pipeline for K and V TMA consumption.
- TMADestination
- TMAProducerPipeline: Unified producer pipeline for K and V TMA loading.
- TMemTile
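The pipeline structs above are described only as barrier-management helpers. A common pattern for this kind of multi-stage barrier bookkeeping is a ring of stages with a phase bit that flips each time the ring wraps around. The sketch below models that pattern in Python, assuming it applies here; `StageCursor` and `trace` are hypothetical names for illustration and do not reflect the actual fields or methods of `StagedPipeline` or `RolePipeline`.

```python
from dataclasses import dataclass


@dataclass
class StageCursor:
    """Hypothetical model: current ring-buffer stage and its phase bit."""

    num_stages: int
    index: int = 0
    phase: int = 0

    def step(self) -> None:
        """Advance one stage; flip the phase bit when the ring wraps."""
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


def trace(num_stages: int, steps: int) -> list[tuple[int, int]]:
    """Record the (stage index, phase) pair seen at each of `steps` steps."""
    cursor = StageCursor(num_stages)
    seen = []
    for _ in range(steps):
        seen.append((cursor.index, cursor.phase))
        cursor.step()
    return seen
```

With two stages, `trace(2, 5)` visits stage 0 and stage 1 with phase 0, then both stages again with phase 1, and returns to stage 0 with phase 0. A waiter can use the phase bit to tell a fresh arrival on a stage apart from one left over from the previous pass.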
Functions
- add_ftz
- add_ftz_rm
- apply_mask
- apply_oob_mask
- break_into_powers_of_two
- build_mma_ss
- build_mma_ts
- bulk_mma
- cumulative_power_of_two
- elect
- elect_mma_arrive: Arrive at the mbar pointer for the MMA instruction.
- exp2_emulation
- extract_power_of_two
- fma_ftz
- intrin
- intrin_ftz
- intrin_ftz_x2
- llvm_opaque_tid
- max_ftz
- maximum
- mul_ftz
- sub_ftz
- sum