Mojo module

attention_utils

Shared SM100 attention primitives used by both the MHA (multi-head attention) and MLA (multi-head latent attention) kernels.

This module contains generic SM100 (Blackwell) GPU primitives, including:

  • Tensor memory (TMEM) access helpers (TMemTile, STMatrixLayout)
  • Pipeline synchronization (StagedPipeline, RolePipeline, etc.)
  • Flush-to-zero (FTZ) arithmetic (add_ftz, sub_ftz, mul_ftz, etc.)
  • Barrier helpers (FA4MiscMBars)
  • Matrix multiply-accumulate (MMA) building blocks (bulk_mma, SM100TensorAccumulatorSS/TS)
  • Masking utilities (apply_mask, apply_oob_mask)
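
To make flush-to-zero (FTZ) arithmetic concrete, here is a host-side Python sketch. It is not the module's code. It assumes the semantics of the PTX `.ftz` modifier on `f32` instructions, where subnormal inputs and subnormal results are replaced by a zero of the same sign. The helper names (`to_f32`, `flush`, `add_ftz_ref`, `mul_ftz_ref`) are hypothetical. The real `add_ftz` and `mul_ftz` run on the GPU and may behave differently in corner cases that this sketch does not model.

```python
import math
import struct

# Smallest positive *normal* float32 value (2**-126). Any nonzero value with a
# smaller magnitude is a float32 subnormal.
FLT_MIN = 2.0 ** -126


def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 value."""
    try:
        return struct.unpack("<f", struct.pack("<f", x))[0]
    except OverflowError:
        # struct rejects finite values that round past the float32 range;
        # under round-to-nearest those become infinity.
        return math.copysign(math.inf, x)


def flush(x: float) -> float:
    """Replace a subnormal value with a zero of the same sign."""
    if x != 0.0 and abs(x) < FLT_MIN:
        return math.copysign(0.0, x)
    return x


def add_ftz_ref(a: float, b: float) -> float:
    # Hypothetical reference: flush inputs, add at float32 precision, flush result.
    return flush(to_f32(flush(to_f32(a)) + flush(to_f32(b))))


def mul_ftz_ref(a: float, b: float) -> float:
    # Same scheme for multiplication. The binary64 product of two float32
    # values is exact, so rounding once to float32 is correct.
    return flush(to_f32(flush(to_f32(a)) * flush(to_f32(b))))


tiny = 2.0 ** -130                             # a float32 subnormal
print(add_ftz_ref(tiny, 0.0))                  # 0.0; plain float32 addition keeps 2**-130
print(add_ftz_ref(FLT_MIN, -FLT_MIN / 2))      # 2**-126; the subnormal input is flushed first
print(mul_ftz_ref(2.0 ** -70, 2.0 ** -70))     # 0.0; the subnormal product 2**-140 is flushed
```

The second call shows why input flushing matters. Without it, plain float32 arithmetic would produce the subnormal `2**-127`.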

comptime values

ConsumerPipeline

comptime ConsumerPipeline = RolePipeline[?, False, ?, ?, ?]

KConsumerPipeline

comptime KConsumerPipeline = TMAConsumerPipeline[?, ?]

KPipeline

comptime KPipeline = StagedPipeline[?, ?]
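
On NVIDIA GPUs, an mbarrier completes in successive phases, and a waiter checks the parity of the phase it expects. Multistage pipelines therefore commonly pair each stage index with a one-bit phase that flips every time the index wraps back to zero, which lets the same barrier be reused on every pass through the ring. This page does not document the parameters of `StagedPipeline`. The Python sketch below only illustrates that generic index/phase bookkeeping with a hypothetical `StageState` class and does not describe how this module implements it.

```python
from dataclasses import dataclass


@dataclass
class StageState:
    """Hypothetical ring-buffer cursor for an N-stage pipeline."""

    num_stages: int
    index: int = 0  # current stage slot (and its barrier)
    phase: int = 0  # parity bit, toggled each time index wraps to 0

    def step(self) -> None:
        """Advance to the next stage, flipping the phase on wrap-around."""
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1


state = StageState(num_stages=2)
for _ in range(5):
    print(state.index, state.phase)
    state.step()
```

With two stages, the cursor visits `(0, 0)`, `(1, 0)`, `(0, 1)`, `(1, 1)`, `(0, 0)`. A waiter that remembers the phase can distinguish the current use of slot 0 from the previous one, even though both uses share one barrier.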

KProducerPipeline

comptime KProducerPipeline = TMAProducerPipeline[?, ?]

KVPipeline

comptime KVPipeline = StagedPipeline[?, ?]

LocalLT

comptime LocalLT[dtype: DType, layout: Layout, element_layout: Layout = Layout(IntTuple(1), IntTuple(1))] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]

LocalTensor

comptime LocalTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutExternalOrigin, address_space=AddressSpace.LOCAL]

MBarType

comptime MBarType = UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]

ProducerPipeline

comptime ProducerPipeline = RolePipeline[?, producer_sub_stages=?, consumer_sub_stages=?, cta_group=?]

SharedMemPointer

comptime SharedMemPointer[type: AnyType] = UnsafePointer[type, MutAnyOrigin, address_space=AddressSpace.SHARED]

SharedMemTensor

comptime SharedMemTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutExternalOrigin, address_space=AddressSpace.SHARED]

VConsumerPipeline

comptime VConsumerPipeline = TMAConsumerPipeline[?, ?, False]

VPipeline

comptime VPipeline = StagedPipeline[?]

VProducerPipeline

comptime VProducerPipeline = TMAProducerPipeline[?, ?, False]

Structs

Functions