Mojo function
mla_combine_kernel
mla_combine_kernel[output_type: DType, accum_type: DType, head_dim: Int, num_splits: Int, ragged: Bool = False, warps_per_head: Int = 2](params: CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head])