Mojo function

mla_combine_kernel

mla_combine_kernel[output_type: DType, accum_type: DType, head_dim: Int, num_splits: Int, ragged: Bool = False, warps_per_head: Int = 2](params: CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head])