Mojo function
top_k
top_k[dtype: DType, out_idx_type: DType, //, largest: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](input: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], max_k: Int, axis: Int, out_vals: TileTensor[dtype, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], out_idxs: TileTensor[out_idx_type, LayoutType, origin, address_space=address_space, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], sorted: Bool, ctx: DeviceContextPtr, k: Optional[TileTensor[DType.int64, Layout[RuntimeInt[DType.int64], ComptimeInt[1]], ImmutAnyOrigin]] = None)
Implementation of the Top K algorithm. Returns the top or bottom K elements and their indices along a specified axis.
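The following minimal reference sketch shows the selection rule for a 1-D input. It is written in Python and is not the Mojo API. The function name `top_k_1d` is hypothetical. Breaking ties in favor of the smaller index is an assumption based on the "stable" wording of the `sorted` argument.

```python
def top_k_1d(xs, k, largest=True):
    """Hypothetical reference: return the k largest (or smallest) values of
    xs and their positions, in stable sorted order (ties keep input order)."""
    if not 0 <= k <= len(xs):
        raise ValueError("k must be between 0 and len(xs)")
    # sorted() is stable, and reverse=True still keeps equal elements in
    # their original order, so ties go to the smaller index (an assumption).
    order = sorted(range(len(xs)), key=lambda i: xs[i], reverse=largest)
    idxs = order[:k]
    vals = [xs[i] for i in idxs]
    return vals, idxs


print(top_k_1d([3, 1, 4, 1, 5, 9, 2, 6], 3))                 # ([9, 6, 5], [5, 7, 4])
print(top_k_1d([3, 1, 4, 1, 5, 9, 2, 6], 3, largest=False))  # ([1, 1, 2], [1, 3, 6])
```

Setting `largest=False` flips the ordering, so the same code selects the bottom K elements.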
Parameters:
- dtype (DType): Data type of the input tensor (also used for out_vals).
- out_idx_type (DType): Data type of the output indices, inferred from out_idxs (for example, DType.int64).
- largest (Bool): If True (the default), select the K largest values (top k). If False, select the K smallest values (bottom k).
- target (StringSlice): The target to run on (default "cpu").
Args:
- input (TileTensor): The input tensor.
- max_k (Int): The maximum number of elements to select, which is the upper bound on k.
- axis (Int): The axis along which to select elements.
- out_vals (TileTensor): Output tensor that receives the selected values.
- out_idxs (TileTensor): Output tensor that receives the indices of the selected values along axis.
- sorted (Bool): Whether the selected top/bottom K elements are returned in (stable) sorted order.
- ctx (DeviceContextPtr): The device call context.
- k (Optional): Optional per-batch-element k values, given as an int64 TileTensor.
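This hypothetical pure-Python reference sketches how axis, max_k, and the per-batch k interact for a 2-D input. `top_k_2d` is not part of the library, and the sketch makes several assumptions:

- k defaults to max_k for every batch element.
- Output slots beyond a batch element's k are padded with None. This page does not document how the real kernel fills them.
- Output is always sorted. The sketch does not model sorted=False.

```python
def top_k_2d(rows, max_k, axis=1, largest=True, k=None):
    """Hypothetical reference for a 2-D list-of-lists input.

    Selects along `axis` (0 or 1). `k`, if given, holds one k value per
    batch element (each <= max_k). Assumptions of this sketch: k defaults
    to max_k, and output slots past a batch element's k are filled with None.
    """
    if axis == 0:
        # Make columns the batch dimension so the loop below always
        # selects along the last axis.
        rows = [list(col) for col in zip(*rows)]
    elif axis != 1:
        raise ValueError("this sketch only supports axis 0 or 1")

    out_vals, out_idxs = [], []
    for b, row in enumerate(rows):
        kb = max_k if k is None else k[b]  # assumed default: k = max_k
        if not 0 <= kb <= min(max_k, len(row)):
            raise ValueError(f"invalid k={kb} for batch element {b}")
        # Stable sort: ties keep their original (smaller-index-first) order.
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=largest)[:kb]
        pad = [None] * (max_k - kb)  # assumed filler for unused slots
        out_vals.append([row[i] for i in order] + pad)
        out_idxs.append(order + pad)

    if axis == 0:
        # Transpose back so the selected entries run along axis 0.
        out_vals = [list(r) for r in zip(*out_vals)]
        out_idxs = [list(r) for r in zip(*out_idxs)]
    return out_vals, out_idxs


print(top_k_2d([[1, 5, 3], [4, 2, 6]], 2, k=[1, 2]))
# ([[5, None], [6, 4]], [[1, None], [2, 0]])
```

With axis=0 the same input yields one selection per column. For example, `top_k_2d([[1, 5, 3], [4, 2, 6]], 1, axis=0)` gives `([[4, 5, 6]], [[1, 0, 1]])`.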