Mojo function
top_k
top_k[dtype: DType, out_idx_type: DType, //, largest: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_k: Int, axis: Int, out_vals: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_idxs: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sorted: Bool, ctx: DeviceContextPtr, k: OptionalReg[LayoutTensor[DType.int64, Layout.row_major(-1), MutableAnyOrigin]] = OptionalReg[LayoutTensor[DType.int64, Layout.row_major(-1), MutableAnyOrigin]]({:i1 0, 1}))
Implementation of the Top K algorithm. Returns the top or bottom K elements and their indices along a specified axis.
Parameters:
- dtype (DType): Data type of the input buffer.
- out_idx_type (DType): The data type of the output indices (default is DType.int64).
- largest (Bool): Whether to return the largest values (top k) or the smallest values (bottom k). Defaults to True.
- target (StringSlice): The target device to run on. Defaults to "cpu".
Args:
- input (LayoutTensor): The input tensor.
- max_k (Int): The maximum number of top or bottom elements to return.
- axis (Int): The axis along which to operate.
- out_vals (LayoutTensor): Output values.
- out_idxs (LayoutTensor): Output indices.
- sorted (Bool): Whether the top/bottom K elements are returned in (stable) sorted order.
- ctx (DeviceContextPtr): The device call context.
- k (OptionalReg): Optional per-batch-element k values, given as a 1-D DType.int64 tensor.
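To make the semantics above concrete, here is a minimal reference model in plain Python. It is not the Mojo API and does not call top_k; the name top_k_rows and its per_row_k argument are hypothetical, introduced only for illustration. It assumes a 2-D input with the operation applied along the last axis, always returns results in stable sorted order, and simply returns shorter rows when a per-row k is smaller than max_k (how the real function fills out_vals and out_idxs in that case is not stated on this page).

```python
def top_k_rows(rows, max_k, largest=True, per_row_k=None):
    """Illustrative model of top/bottom-k along the last axis of a 2-D list.

    Hypothetical helper, not part of the Mojo library.
    Returns (values, indices), one list per input row.
    """
    out_vals, out_idxs = [], []
    for r, row in enumerate(rows):
        # Assumption: a per-row k, when given, is capped by max_k.
        k = max_k if per_row_k is None else min(per_row_k[r], max_k)
        # Stable sort of positions by value; ties keep the lower index first.
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=largest)
        idxs = order[:k]
        out_vals.append([row[i] for i in idxs])
        out_idxs.append(idxs)
    return out_vals, out_idxs


rows = [[1, 5, 3, 5], [9, 2, 7, 4]]
top_vals, top_idxs = top_k_rows(rows, max_k=2)                  # largest=True
bot_vals, bot_idxs = top_k_rows(rows, max_k=2, largest=False)   # bottom k
```

With largest=True the first row yields values [5, 5] at indices [1, 3]: the tie is broken by position, which is what stable ordering means here.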
