Mojo function
mha_splitk_reduce
mha_splitk_reduce[
    output_type: DType,
    depth: Scalar[DType.uindex],
    num_heads: Scalar[DType.uindex],
    num_threads: Scalar[DType.uindex],
    group: Scalar[DType.uindex] = 1,
    use_exp2: Bool = False
](
    intermediate_ptr: LegacyUnsafePointer[Scalar[output_type]],
    output_ptr: LegacyUnsafePointer[Scalar[output_type]],
    exp_sum_ptr: LegacyUnsafePointer[Scalar[get_accum_type[output_type]()]],
    qk_max_ptr: LegacyUnsafePointer[Scalar[get_accum_type[output_type]()]],
    batch_size: Int,
    num_partitions: Int
)
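
The page gives only the signature, so the following is a minimal sketch of what a split-K attention reduction typically computes, not the kernel's actual implementation. It assumes that each of the `num_partitions` partitions has produced an unnormalized partial output (as `intermediate_ptr` might hold), a partial exponential sum (`exp_sum_ptr`), and a partial score maximum (`qk_max_ptr`) for one query and one head. The helper names `partial_attention` and `splitk_reduce` are hypothetical, and plain Python lists stand in for the pointer arguments.

```python
import math


def partial_attention(scores, values):
    """Hypothetical per-partition step.

    Returns (unnormalized output, exp_sum, qk_max) for one slice of keys.
    Assumption: the partial output is NOT divided by exp_sum.
    """
    qk_max = max(scores)
    weights = [math.exp(s - qk_max) for s in scores]
    exp_sum = sum(weights)
    depth = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(depth)]
    return out, exp_sum, qk_max


def splitk_reduce(partials):
    """Hypothetical reduction over partitions.

    Rescales every partition to a shared maximum, sums the rescaled
    outputs and exponential sums, then normalizes once at the end.
    """
    global_max = max(qk_max for _, _, qk_max in partials)
    depth = len(partials[0][0])
    acc = [0.0] * depth
    total = 0.0
    for out, exp_sum, qk_max in partials:
        # Correction factor that moves this partition onto the global max.
        scale = math.exp(qk_max - global_max)
        total += exp_sum * scale
        for d in range(depth):
            acc[d] += out[d] * scale
    return [a / total for a in acc]
```

Under these assumptions, splitting the keys into any number of partitions and reducing them gives the same result as a single softmax-weighted sum over all keys. The `use_exp2` parameter suggests a base-2 exponential variant, which this sketch does not model.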