Skip to main content

Mojo module

mha_sm100_2q

comptime values

ConsumerPipeline

comptime ConsumerPipeline = RolePipeline[?, False, ?, ?]

EnableForcedOrdering

comptime EnableForcedOrdering = env_get_bool["FA4ForcedSoftmaxOrdering", False]()

KConsumerPipeline

comptime KConsumerPipeline = TMAConsumerPipeline[?, ?]

KPipeline

comptime KPipeline = StagedPipeline[?, ?]

KProducerPipeline

comptime KProducerPipeline = TMAProducerPipeline[?, ?]

KVPipeline

comptime KVPipeline = StagedPipeline[?, ?]

LocalTensor

comptime LocalTensor[dtype: DType, layout: Layout, element_layout: Layout = Layout(IntTuple(1), IntTuple(1))] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]

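Parameters

- `dtype` (`DType`): the data type, forwarded to `LayoutTensor`.
- `layout` (`Layout`): the tensor layout, forwarded to `LayoutTensor`.
- `element_layout` (`Layout`, default `Layout(IntTuple(1), IntTuple(1))`): forwarded to `LayoutTensor` as its `element_layout`.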

logger

comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)

MBarType

comptime MBarType = UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]

ProducerPipeline

comptime ProducerPipeline = RolePipeline[?, producer_sub_stages=?, consumer_sub_stages=?]

SharedMemPointer

comptime SharedMemPointer[type: AnyType] = UnsafePointer[type, MutAnyOrigin, address_space=AddressSpace.SHARED]

Parameters

- `type` (`AnyType`): the pointee type of the `UnsafePointer` in `AddressSpace.SHARED`.

SharedMemTensor

comptime SharedMemTensor[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]

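Parameters

- `dtype` (`DType`): the data type, forwarded to `LayoutTensor`.
- `layout` (`Layout`): the tensor layout, forwarded to `LayoutTensor`.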

VConsumerPipeline

comptime VConsumerPipeline = TMAConsumerPipeline[?, ?, False]

VPipeline

comptime VPipeline = StagedPipeline[?]

VProducerPipeline

comptime VProducerPipeline = TMAProducerPipeline[?, ?, False]

Structs

Functions

Was this page helpful?