
Mojo function

top_k

top_k[rank: Int, type: DType, out_idx_type: DType, //, largest: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](input: NDBuffer[type, rank, origin], k: Int, axis: Int, out_vals: NDBuffer[type, rank, origin], out_idxs: NDBuffer[out_idx_type, rank, origin], sorted: Bool, ctx: DeviceContextPtr)

Implementation of the Top-K algorithm. Finds the top (largest) or bottom (smallest) K elements along a specified axis and writes their values to out_vals and their indices to out_idxs.
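As an illustration of these semantics, here is a minimal pure-Python reference sketch for a 1-D input. It is not the Mojo API: the function name top_k_1d is hypothetical, and breaking ties by lower index first is an assumption of this sketch (it follows from Python's stable sort), not documented behavior of top_k.

```python
# Hypothetical pure-Python reference for the semantics of top_k on a
# 1-D input. This is an illustration only, not the Mojo API.
def top_k_1d(values, k, largest=True):
    """Return (vals, idxs) for the k largest (or smallest) elements.

    Python's sort is stable (also with reverse=True), so equal values
    keep ascending index order. That tie-breaking rule is an assumption
    of this sketch, not documented top_k behavior.
    """
    if not 0 <= k <= len(values):
        raise ValueError("k must be between 0 and len(values)")
    # Sort positions by their value: descending for largest, ascending otherwise.
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=largest)
    idxs = order[:k]
    vals = [values[i] for i in idxs]
    return vals, idxs


vals, idxs = top_k_1d([1.0, 5.0, 3.0, 5.0, 2.0], k=3)
print(vals, idxs)  # [5.0, 5.0, 3.0] [1, 3, 2]
print(top_k_1d([1.0, 5.0, 3.0, 5.0, 2.0], k=2, largest=False))  # ([1.0, 2.0], [0, 4])
```

The indices refer to positions in the original input, which is what out_idxs holds alongside the values in out_vals.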

Parameters:

  • rank (Int): Rank of the input.
  • type (DType): Data type of the input buffer.
  • out_idx_type (DType): The data type of the output indices. It is inferred from the out_idxs argument.
  • largest (Bool): If True (the default), finds the K largest values (top K); if False, finds the K smallest values (bottom K).
  • target (StringSlice[StaticConstantOrigin]): The target device to run on (defaults to "cpu").

Args:

  • input (NDBuffer[type, rank, origin]): The input tensor.
  • k (Int): The number of largest or smallest elements to select.
  • axis (Int): The axis along which to select the elements.
  • out_vals (NDBuffer[type, rank, origin]): Output buffer that receives the selected values.
  • out_idxs (NDBuffer[out_idx_type, rank, origin]): Output buffer that receives the indices, along axis, of the selected values in input.
  • sorted (Bool): Whether the top or bottom K elements are written in (stable) sorted order.
  • ctx (DeviceContextPtr): The device call context.
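To illustrate how axis shapes the outputs, the following hypothetical pure-Python sketch applies the same selection along axis 0 or 1 of a 2-D list. It assumes, as the signature suggests, that out_vals and out_idxs have the same rank as input, and further assumes their extent along axis is k; the helper name top_k_2d is illustrative only and is not part of the Mojo API.

```python
# Hypothetical pure-Python sketch of top-k along one axis of a 2-D list.
# Illustration only; not the Mojo API.
def top_k_2d(matrix, k, axis, largest=True):
    """Return (vals, idxs) with the same rank as matrix and extent k along axis."""
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1 for a 2-D input")
    # Treat each slice along `axis` as a row: transpose when axis == 0.
    rows = [list(r) for r in matrix] if axis == 1 else [list(c) for c in zip(*matrix)]
    if not all(1 <= k <= len(r) for r in rows):
        raise ValueError("k must be between 1 and the extent of axis")
    vals, idxs = [], []
    for row in rows:
        # Positions of the k largest (or smallest) entries; the stable sort keeps
        # ties in ascending index order (an assumption of this sketch).
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=largest)[:k]
        vals.append([row[i] for i in order])
        idxs.append(order)
    if axis == 0:
        # Transpose back so that dimension 0 has extent k.
        vals = [list(r) for r in zip(*vals)]
        idxs = [list(r) for r in zip(*idxs)]
    return vals, idxs


m = [[1, 9, 3],
     [7, 2, 8]]
print(top_k_2d(m, 2, axis=1))  # ([[9, 3], [8, 7]], [[1, 2], [2, 0]])
print(top_k_2d(m, 1, axis=0))  # ([[7, 9, 8]], [[1, 0, 1]])
```

With axis=0 and k=1, the (2, 3) input yields (1, 3) outputs, and each index gives the row of input that holds the selected value in that column.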
