Mojo module
mha_sm100_2q
comptime values
ConsumerPipeline
comptime ConsumerPipeline = RolePipeline[?, False, ?, ?]
EnableForcedOrdering
comptime EnableForcedOrdering = env_get_bool["FA4ForcedSoftmaxOrdering", False]()
KConsumerPipeline
comptime KConsumerPipeline = TMAConsumerPipeline[?, ?]
KPipeline
comptime KPipeline = StagedPipeline[?, ?]
KProducerPipeline
comptime KProducerPipeline = TMAProducerPipeline[?, ?]
KVPipeline
comptime KVPipeline = StagedPipeline[?, ?]
LocalTensor
comptime LocalTensor[dtype: DType, layout: Layout, element_layout: Layout = Layout(IntTuple(1), IntTuple(1))] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]
Parameters
logger
comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)
MBarType
comptime MBarType = UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED]
ProducerPipeline
comptime ProducerPipeline = RolePipeline[?, producer_sub_stages=?, consumer_sub_stages=?]
SharedMemPointer
comptime SharedMemPointer[type: AnyType] = UnsafePointer[type, MutAnyOrigin, address_space=AddressSpace.SHARED]
Parameters
- type (AnyType):
SharedMemTensor
comptime SharedMemTensor[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]
Parameters
VConsumerPipeline
comptime VConsumerPipeline = TMAConsumerPipeline[?, ?, False]
VPipeline
comptime VPipeline = StagedPipeline[?]
VProducerPipeline
comptime VProducerPipeline = TMAProducerPipeline[?, ?, False]
Structs
-
FA4Config: -
FA4MiscMBars: Manages all mbarrier resources for FA4. -
MBarPipeline: -
RolePipeline: Unified producer/consumer pipeline for barrier synchronization. -
SM100MHA2Q: -
SM100TensorAccumulatorSS: -
SM100TensorAccumulatorTS: -
StagedPipeline: Unified pipeline for K, V, and KV tile barrier management. -
STMatrixLayout: Layout for using `st_matrix` for writing the final accumulator to smem. -
STMatrixOffsets: -
TMAConsumerPipeline: Unified consumer pipeline for K and V TMA consumption. -
TMADestination: -
TMAProducerPipeline: Unified producer pipeline for K and V TMA loading. -
TMemTile:
Functions
-
add_ftz: -
add_ftz_rm: -
apply_mask: -
apply_oob_mask: -
break_into_powers_of_two: -
build_mma_ss: -
build_mma_ts: -
bulk_mma: -
cumulative_power_of_two: -
elect: -
elect_mma_arrive: Arrive at the mbar pointer for the MMA instruction. -
exp2_emulation: -
extract_power_of_two: -
fma_ftz: -
intrin: -
intrin_ftz: -
intrin_ftz_x2: -
llvm_opaque_tid: -
max3: -
max_ftz: -
maximum: -
mha_sm100_dispatch: -
mul_ftz: -
sub_ftz: -
sum:
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!