Mojo function
mha
mha[
    q_type: DType,
    k_t: MHAOperand,
    v_t: MHAOperand,
    output_type: DType,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    valid_length_layout: Layout,
    config: MHAConfig[dtype],
    group: Int = 1,
    use_score_mod: Bool = False,
    ragged: Bool = False,
    is_shared_kv: Bool = False,
    sink: Bool = False,
    _use_valid_length: Bool = False,
    _is_cache_length_accurate: Bool = False,
    _padded_ndbuffer: Bool = False
](
    q_ptr: LegacyUnsafePointer[Scalar[q_type]],
    k: k_t,
    v: v_t,
    output_ptr: LegacyUnsafePointer[Scalar[output_type]],
    scale: Float32,
    batch_size: Int,
    seq_len_arg: Int,
    num_keys_arg: Int,
    valid_length: LayoutTensor[DType.uint32, valid_length_layout, MutAnyOrigin],
    kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]],
    sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]],
    mask: mask_t,
    score_mod: score_mod_t
)
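This page gives only the signature, so the following is a hedged, CPU-only reference sketch of what a multi-head attention call of this shape usually computes. It is not the Mojo kernel and is not based on its source. It rests on two assumptions that are not stated on this page: each query head computes softmax(scale * Q Kᵀ) V, and `group` is the number of query heads that share one key/value head, as in grouped-query attention. The helpers `attention_head` and `mha_reference` are hypothetical names. Masking (`mask`), `score_mod`, attention sinks (`sink_weights`), ragged batches, and valid lengths are all omitted.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_head(q: Matrix, k: Matrix, v: Matrix, scale: float) -> Matrix:
    # q: seq_len x d, k: num_keys x d, v: num_keys x d_v (hypothetical layout).
    out = []
    for q_row in q:
        # Scaled dot-product score of this query against every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        probs = softmax(scores)
        # Probability-weighted sum of the value rows.
        out.append([sum(probs[j] * v[j][c] for j in range(len(v)))
                    for c in range(len(v[0]))])
    return out


def mha_reference(q_heads: List[Matrix], k_heads: List[Matrix],
                  v_heads: List[Matrix], scale: float, group: int = 1) -> List[Matrix]:
    # ASSUMPTION: `group` query heads share one K/V head (grouped-query attention),
    # so query head h reads key/value head h // group.
    if len(q_heads) != len(k_heads) * group or len(k_heads) != len(v_heads):
        raise ValueError("expected num_q_heads == num_kv_heads * group")
    return [attention_head(q, k_heads[h // group], v_heads[h // group], scale)
            for h, q in enumerate(q_heads)]
```

With `group=1`, every query head has its own key/value head, which is ordinary multi-head attention. Larger values let several query heads reuse the same cached keys and values.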