Skip to main content

Mojo function

mla_splitk_reduce

mla_splitk_reduce[intermediate_type: DType, output_type: DType, depth: Int, num_heads: Int, D_TILES: Int, W_PARTS: Int, MAX_PARTITIONS: Int, use_exp2: Bool = False](intermediate_ptr: UnsafePointer[Scalar[intermediate_type], ImmutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], exp_sum_ptr: UnsafePointer[Scalar[get_accum_type[output_type]()], MutAnyOrigin], qk_max_ptr: UnsafePointer[Scalar[get_accum_type[output_type]()], MutAnyOrigin], batch_size: Int, num_partitions: Int)