Mojo function
launch_mla_combine_kernel
launch_mla_combine_kernel[
    output_type: DType,
    accum_type: DType,
    head_dim: Int,
    num_splits: Int,
    ragged: Bool = False,
    warps_per_head: Int = 2
](
    out_accum_split: TileTensor[output_type, out_accum_split.LayoutType, out_accum_split.origin, linear_idx_type=out_accum_split.linear_idx_type, element_size=out_accum_split.element_size],
    lse_accum_split: TileTensor[accum_type, lse_accum_split.LayoutType, lse_accum_split.origin, linear_idx_type=lse_accum_split.linear_idx_type, element_size=lse_accum_split.element_size],
    output: TileTensor[output_type, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin],
    batch_size: Int,
    seq_len: Int,
    num_heads: Int,
    ctx: DeviceContext
)