Mojo function
produce
```mojo
produce[
    qkv_type: DType,
    BM: Int,
    BN: Int,
    q_smem_layout: Layout,
    q_desc_layout: Layout,
    k_smem_layout: Layout,
    k_desc_layout: Layout,
    v_smem_layout: Layout,
    v_desc_layout: Layout,
    depth: Int,
    padded_depth: Int,
    num_heads: Int,
    group: Int,
    PartitionType: MHAPartitionScheme,
    MaxSeqLenType: OptionallyStaticInt,
    SchedulerType: MHATileScheduler,
    KVLUTType: MHAOperand,
    MaskType: MHAMask,
    KVInputRowOffsetsType: OptionalPointer,
    ValidLengthType: OptionalPointer,
    //,
    swizzle_mode: TensorMapSwizzle,
    *,
    pipeline_stages: Int,
    ragged: Bool,
    _is_cache_length_accurate: Bool,
](
    q_tma_op: TMATensorTile[qkv_type, q_smem_layout, q_desc_layout],
    k_tma_op: TMATensorTile[qkv_type, k_smem_layout, k_desc_layout],
    v_tma_op: TMATensorTile[qkv_type, v_smem_layout, v_desc_layout],
    q_smem: LegacyUnsafePointer[Scalar[qkv_type], address_space=AddressSpace.SHARED],
    kv_smem: LegacyUnsafePointer[Scalar[qkv_type], address_space=AddressSpace.SHARED],
    produced_mbar_kv: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED],
    consumed_mbar_kv: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED],
    produced_mbar_q: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED],
    consumed_mbar_q: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED],
    kv_lut: KVLUTType,
    initial_position: MHAPosition[BM, BN, depth, padded_depth, num_heads, group, _is_decoding[MaxSeqLenType]()],
    partition: PartitionType,
    scheduler: SchedulerType,
    mask: MaskType,
    tile_summary: MHATileSummary[ValidLengthType],
    tile_state_arg: MHATileState,
    max_seq_len: MaxSeqLenType,
    num_keys_arg: UInt32,
    kv_input_row_offsets: KVInputRowOffsetsType,
)
```
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!