Mojo function

mha_splitk_reduce

mha_splitk_reduce[
    output_type: DType,
    depth: UInt,
    num_heads: UInt,
    num_threads: UInt,
    group: UInt = UInt(1),
    use_exp2: Bool = False
](
    intermediate_ptr: UnsafePointer[SIMD[output_type, 1]],
    output_ptr: UnsafePointer[SIMD[output_type, 1]],
    exp_sum_ptr: UnsafePointer[SIMD[get_accum_type[::DType,::DType](), 1]],
    qk_max_ptr: UnsafePointer[SIMD[get_accum_type[::DType,::DType](), 1]],
    batch_size: Int,
    num_partitions: Int
)
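
This page gives only the signature, so the following is a rough sketch of what a split-k reduction for multi-head attention usually computes, not the Mojo kernel itself. It is written in plain Python for one (batch, head) pair. The helper names `partial_attention` and `splitk_reduce` are hypothetical, and the sketch assumes that each partition's intermediate output is already normalized by its own exponential sum, that `exp_sums` and `qk_maxes` play the roles suggested by `exp_sum_ptr` and `qk_max_ptr`, and that `use_exp2` simply switches the exponential base from e to 2.

```python
import math


def _exp(x, use_exp2=False):
    # Assumption: use_exp2 only changes the exponential base from e to 2.
    return 2.0 ** x if use_exp2 else math.exp(x)


def partial_attention(scores, values, use_exp2=False):
    """Attention over one partition of the keys (hypothetical helper).

    Returns the locally normalized output vector, the local exponential
    sum, and the local maximum score.
    """
    local_max = max(scores)
    exps = [_exp(s - local_max, use_exp2) for s in scores]
    local_sum = sum(exps)
    depth = len(values[0])
    out = [
        sum(e * v[d] for e, v in zip(exps, values)) / local_sum
        for d in range(depth)
    ]
    return out, local_sum, local_max


def splitk_reduce(partials, exp_sums, qk_maxes, use_exp2=False):
    """Merge per-partition results into one output vector (sketch).

    Each partition is re-weighted by its exponential sum, rescaled from
    its local maximum to the global maximum, then the weighted average
    of the partial outputs is taken.
    """
    global_max = max(qk_maxes)
    weights = [
        s * _exp(m - global_max, use_exp2)
        for s, m in zip(exp_sums, qk_maxes)
    ]
    total = sum(weights)
    depth = len(partials[0])
    return [
        sum(w * p[d] for w, p in zip(weights, partials)) / total
        for d in range(depth)
    ]
```

Under these assumptions, splitting the keys into partitions, calling `partial_attention` on each, and passing the results to `splitk_reduce` gives the same vector as running `partial_attention` once over all keys, up to floating-point rounding. Rescaling to the global maximum is what keeps the merge numerically stable when partitions have very different score ranges.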
