Mojo function
mla_decode_combine_partial_outputs
mla_decode_combine_partial_outputs[output_type: DType, accum_type: DType, head_dim: Int, num_splits: Int, ragged: Bool = False, warps_per_head: Int = 2](out_accum_split: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], lse_accum_split: LayoutTensor[accum_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], batch_size: Int, seq_len: Int, num_heads: Int, ctx: DeviceContext)
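This page gives only the signature. The parameter names (`out_accum_split`, `lse_accum_split`, `num_splits`) suggest the split-KV ("flash-decoding") reduction. In that reduction, each split of the key/value sequence produces a partial attention output together with the log-sum-exp (LSE) of its scores, and the final output is a weighted sum of the partial outputs. The sketch below shows that reduction in plain Python for a single (token, head) pair, assuming this is what the kernel computes. It does not model the tensor layouts, ragged batching via `input_row_offsets_ptr`, or the `warps_per_head` work split. The helper names `attend` and `combine_partial_outputs` are hypothetical.

```python
import math

def attend(scores, values):
    """Hypothetical helper: softmax attention over one block of keys.

    Returns (output vector, log-sum-exp of the scores). Applied to one
    split it gives that split's partial output and LSE; applied to all
    keys it gives the reference (unsplit) result.
    """
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    head_dim = len(values[0])
    out = [sum(e * v[d] for e, v in zip(exps, values)) / denom
           for d in range(head_dim)]
    return out, m + math.log(denom)

def combine_partial_outputs(partial_outs, partial_lse):
    """Hypothetical sketch of the split-KV combine for one (token, head).

    partial_outs: num_splits vectors of length head_dim, each already
        normalized by its own split's softmax denominator.
    partial_lse:  num_splits log-sum-exp values, one per split.
    """
    assert len(partial_outs) == len(partial_lse)
    head_dim = len(partial_outs[0])
    max_lse = max(partial_lse)
    if max_lse == float("-inf"):
        # Every split was empty: nothing to combine.
        return [0.0] * head_dim, max_lse
    # Subtract the largest LSE before exponentiating for numerical stability.
    weights = [math.exp(l - max_lse) for l in partial_lse]
    total = sum(weights)
    global_lse = max_lse + math.log(total)
    out = [0.0] * head_dim
    for w, o in zip(weights, partial_outs):
        scale = w / total  # equals exp(lse_i - global_lse)
        for d in range(head_dim):
            out[d] += scale * o[d]
    return out, global_lse
```

Rescaling each partial output by `exp(lse_i - global_lse)` reproduces exactly the softmax over all keys. Because of that, the splits can be computed independently, and only this cheap reduction has to see all of them.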