
Mojo module

attention_utils

Shared SM100 attention primitives used by both MHA and MLA kernels.

This module contains generic SM100 (Blackwell) GPU primitives including:

  • TMEM access helpers (TMemTile, STMatrixLayout)
  • Pipeline synchronization (StagedPipeline, RolePipeline, etc.)
  • FTZ arithmetic (add_ftz, sub_ftz, mul_ftz, etc.)
  • Barrier helpers (FA4MiscMBars)
  • MMA building blocks (bulk_mma, SM100TensorAccumulator)
  • Masking utilities (apply_mask, apply_oob_mask)
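FTZ stands for flush-to-zero. This page does not spell out the semantics of the FTZ helpers. Assuming they follow the usual PTX `.ftz` behavior for float32, subnormal inputs and results are replaced by a zero of the same sign. The threshold $2^{-126}$ (about $1.18 \times 10^{-38}$) is the smallest normal float32. In the sketch, $\operatorname{fl}$ denotes rounding to the nearest float32. `sub_ftz` and `mul_ftz` would follow the same pattern.

```latex
\operatorname{ftz}(x) =
\begin{cases}
\pm 0 \ (\text{same sign as } x) & \text{if } 0 < |x| < 2^{-126} \\
x & \text{otherwise}
\end{cases}
\qquad
\texttt{add\_ftz}(a, b) = \operatorname{ftz}\bigl(\operatorname{fl}(\operatorname{ftz}(a) + \operatorname{ftz}(b))\bigr)
```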

comptime values

ConsumerPipeline

comptime ConsumerPipeline = RolePipeline[_, False, _, _, _]

FP32_EXP_BIAS

comptime FP32_EXP_BIAS = 127
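The value 127 matches the IEEE 754 binary32 exponent bias. A normal float32 has a sign bit $s$, an 8-bit exponent field $E$, and a 23-bit mantissa field $M$, and encodes the value shown on the first line below. Reading the constant as this bias is an inference from its name and value. The last line shows one common use, assembling an exact power of two directly from exponent bits. Whether these kernels use the constant that way is an assumption that this page does not state.

```latex
x = (-1)^{s} \cdot 2^{\,E - 127} \cdot \left(1 + \frac{M}{2^{23}}\right), \qquad 1 \le E \le 254
% Example: 1.0f has s = 0, E = 127, M = 0, so x = 2^{0} = 1.
% Bit pattern of 2^{n} for an integer -126 <= n <= 127 (mantissa field zero):
\operatorname{bits}(2^{n}) = (n + 127) \cdot 2^{23}, \qquad
n = 3:\ 130 \cdot 2^{23} = \texttt{0x41000000} = 8.0
```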

KConsumerPipeline

comptime KConsumerPipeline = TMAConsumerPipeline[_, _]

KPipeline

comptime KPipeline = StagedPipeline[_, _]

KProducerPipeline

comptime KProducerPipeline = TMAProducerPipeline[_, _]

KVPipeline

comptime KVPipeline = StagedPipeline[_, _]

LocalLT

comptime LocalLT[dtype: DType, layout: Layout, element_layout: Layout = Layout(IntTuple(Int(1)), IntTuple(Int(1)))] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]

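Parameters

  • dtype (DType)
  • layout (Layout)
  • element_layout (Layout, default: Layout(IntTuple(Int(1)), IntTuple(Int(1))))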

LocalTensor

comptime LocalTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Parameters

  • dtype (DType)
  • layout (Layout[shape_types, stride_types])

MBarType

comptime MBarType = UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]

ProducerPipeline

comptime ProducerPipeline = RolePipeline[_, producer_sub_stages=_, consumer_sub_stages=_, cta_group=_]

SharedMemPointer

comptime SharedMemPointer[type: AnyType] = UnsafePointer[type, MutAnyOrigin, address_space=AddressSpace.SHARED]

Parameters

  • type (AnyType)

SharedMemTensor

comptime SharedMemTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutUntrackedOrigin, address_space=AddressSpace.SHARED]

Parameters

  • dtype (DType)
  • layout (Layout[shape_types, stride_types])

VConsumerPipeline

comptime VConsumerPipeline = TMAConsumerPipeline[_, _, False]

VPipeline

comptime VPipeline = StagedPipeline[_]

VProducerPipeline

comptime VProducerPipeline = TMAProducerPipeline[_, _, False]

Structs

Functions