Skip to main content

Mojo function

produce

produce[
    qkv_type: DType,
    BM: Int,
    BN: Int,
    depth: Int,
    num_heads: Int,
    group: Int,
    PartitionType: MHAPartitionScheme,
    swizzle_mode: TensorMapSwizzle,
    q_tma_rows: Int,
    q_tma_cols: Int,
    MaxSeqLenType: OptionallyStaticInt,
    SchedulerType: MHATileScheduler,
    KVLUTType: MHAOperand,
    MaskType: MHAMask,
    //,
    *,
    pipeline_stages: Int,
    ragged: Bool,
    _is_cache_length_accurate: Bool
](
    q_tma_op: TMATensorTile[qkv_type, tile_layout_k_major[::DType,::Int,::Int,::TensorMapSwizzle](), _tma_desc_tile_layout[::DType,::Int,::IndexList[$1, ::DType()],
    k_tma_op: TMATensorTile[qkv_type, tile_layout_k_major[::DType,::Int,::Int,::TensorMapSwizzle](), _tma_desc_tile_layout[::DType,::Int,::IndexList[$1, ::DType()],
    v_tma_op: TMATensorTile[qkv_type, tile_layout_mn_major[::DType,::Int,::Int,::TensorMapSwizzle](), _tma_desc_tile_layout[::DType,::Int,::IndexList[$1, ::DType(), False],
    q_smem: UnsafePointer[SIMD[qkv_type, 1], address_space=AddressSpace(3), alignment=128],
    kv_smem: UnsafePointer[SIMD[qkv_type, 1], address_space=AddressSpace(3), alignment=128],
    produced_mbar_kv: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8],
    consumed_mbar_kv: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8],
    produced_mbar_q: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8],
    consumed_mbar_q: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8],
    kv_lut: KVLUTType,
    initial_position: MHAPosition[BM, BN, depth, num_heads, group, _is_decoding[nn::mha_utils::OptionallyStaticInt]()],
    partition: PartitionType,
    scheduler: SchedulerType,
    mask: MaskType,
    tile_summary: MHATileSummary,
    tile_state_arg: MHATileState,
    max_seq_len: MaxSeqLenType,
    num_keys_arg: SIMD[uint32, 1],
    kv_input_row_offsets: OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]]
)

Was this page helpful?