Mojo module

sm100_attention_utils

Shared SM100 attention primitives used by both MHA and MLA kernels.

This module contains generic SM100 (Blackwell) GPU primitives including:

  • TMEM access helpers (TMemTile, STMatrixLayout)
  • Pipeline synchronization (StagedPipeline, RolePipeline, etc.)
  • FTZ arithmetic (add_ftz, sub_ftz, mul_ftz, etc.)
  • Barrier helpers (FA4MiscMBars)
  • MMA building blocks (bulk_mma, SM100TensorAccumulatorSS/TS)
  • Masking utilities (apply_mask, apply_oob_mask)

comptime values

ConsumerPipeline

comptime ConsumerPipeline = RolePipeline[?, False, ?, ?]

KConsumerPipeline

comptime KConsumerPipeline = TMAConsumerPipeline[?, ?]

KPipeline

comptime KPipeline = StagedPipeline[?, ?]

KProducerPipeline

comptime KProducerPipeline = TMAProducerPipeline[?, ?]

KVPipeline

comptime KVPipeline = StagedPipeline[?, ?]

LocalLT

comptime LocalLT[dtype: DType, layout: Layout, element_layout: Layout = Layout(IntTuple(1), IntTuple(1))] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]

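Parameters

  • dtype (DType)
  • layout (Layout)
  • element_layout (Layout): defaults to Layout(IntTuple(1), IntTuple(1))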

LocalTensor

comptime LocalTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Parameters

  • dtype (DType)
  • layout (Layout[shape_types, stride_types])

MBarType

comptime MBarType = UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]

ProducerPipeline

comptime ProducerPipeline = RolePipeline[?, producer_sub_stages=?, consumer_sub_stages=?]

SharedMemLT

comptime SharedMemLT[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]

Parameters

  • dtype (DType)
  • layout (Layout)

SharedMemPointer

comptime SharedMemPointer[type: AnyType] = UnsafePointer[type, MutAnyOrigin, address_space=AddressSpace.SHARED]

Parameters

  • type (AnyType)

SharedMemTensor

comptime SharedMemTensor[dtype: DType, layout: Layout[shape_types, stride_types]] = TileTensor[dtype, Layout[shape_types, stride_types], MutExternalOrigin, address_space=AddressSpace.SHARED]

Parameters

  • dtype (DType)
  • layout (Layout[shape_types, stride_types])

VConsumerPipeline

comptime VConsumerPipeline = TMAConsumerPipeline[?, ?, False]

VPipeline

comptime VPipeline = StagedPipeline[?]

VProducerPipeline

comptime VProducerPipeline = TMAProducerPipeline[?, ?, False]

Structs

Functions