
Mojo function

top_k

top_k[dtype: DType, out_idx_type: DType, //, largest: Bool = True, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](input: TileTensor[dtype, input.LayoutType, input.origin, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], max_k: Int, axis: Int, out_vals: TileTensor[dtype, out_vals.LayoutType, out_vals.origin, address_space=out_vals.address_space, linear_idx_type=out_vals.linear_idx_type, element_size=out_vals.element_size], out_idxs: TileTensor[out_idx_type, out_idxs.LayoutType, out_idxs.origin, address_space=out_idxs.address_space, linear_idx_type=out_idxs.linear_idx_type, element_size=out_idxs.element_size], sorted: Bool, ctx: DeviceContextPtr, k: Optional[TileTensor[DType.int64, Layout[RuntimeInt[DType.int64], ComptimeInt[1]], ImmutAnyOrigin]] = None)

Implementation of the top-k algorithm. Selects the K largest (or, with largest=False, the K smallest) elements along the specified axis, and writes their values to out_vals and their indices to out_idxs.

Parameters:

  • dtype (DType): Data type of the input tensor (and of out_vals).
  • out_idx_type (DType): Data type of the output indices, inferred from out_idxs (default: DType.int64).
  • largest (Bool): If True (the default), select the K largest values (top k); if False, select the K smallest values (bottom k).
  • target (StringSlice): The target to run on (defaults to "cpu").

Args:

  • input (TileTensor): The input tensor.
  • max_k (Int): The maximum value of K, that is, an upper bound on the number of elements selected along axis.
  • axis (Int): The axis along which to operate.
  • out_vals (TileTensor): Output tensor that receives the selected values.
  • out_idxs (TileTensor): Output tensor that receives the indices of the selected values along axis.
  • sorted (Bool): Whether the top/bottom K elements are written in (stable) sorted order.
  • ctx (DeviceContextPtr): The device call context.
  • k (Optional): Optional per-batch-element K values, given as a DType.int64 tensor.
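The semantics of the parameters above can be illustrated with a small plain-Python reference model. This is a sketch under stated assumptions, not the Mojo API. It operates on 2-D nested lists instead of TileTensor, and it treats each slice along axis as one "batch element". It uses heapq for the selection, which is a common technique but not necessarily what the kernel does. The flag is renamed sorted_out to avoid shadowing Python's built-in sorted, and results come back as one list per batch element. For sorted=False it emits the chosen elements in index order, which is just one valid unsorted order. How the real kernel fills unused output slots when a per-batch k is smaller than max_k is not modeled.

```python
import heapq
from typing import List, Optional, Sequence, Tuple


def _select(row: Sequence[float], k: int, largest: bool,
            sorted_out: bool) -> Tuple[list, list]:
    """Pick k entries of one 1-D row; returns (values, indices)."""
    pairs = list(enumerate(row))  # (index, value) pairs
    if largest:
        # Best-first; equal values keep the lower index first (stable).
        best = heapq.nlargest(k, pairs, key=lambda p: p[1])
    else:
        best = heapq.nsmallest(k, pairs, key=lambda p: p[1])
    if not sorted_out:
        # sorted=False only drops the ordering guarantee; this sketch
        # simply emits the chosen elements in their original index order.
        best.sort(key=lambda p: p[0])
    return [v for _, v in best], [i for i, _ in best]


def top_k_ref(x: List[List[float]], max_k: int, axis: int,
              largest: bool = True, sorted_out: bool = True,
              k: Optional[List[int]] = None):
    """Hypothetical reference model of top_k on a 2-D list.

    Returns (out_vals, out_idxs) as one list per batch element, where a
    batch element is one slice along `axis` (axis must be 0 or 1 here).
    """
    if axis not in (0, 1):
        raise ValueError("this sketch supports axis 0 or 1 only")
    # Slices along axis 1 are the rows; slices along axis 0 are the columns.
    slices = x if axis == 1 else [list(col) for col in zip(*x)]
    ks = k if k is not None else [max_k] * len(slices)
    if len(ks) != len(slices):
        raise ValueError("need one k per batch element")
    out_vals, out_idxs = [], []
    for s, kb in zip(slices, ks):
        # Assumption: each per-batch k is at most max_k.
        if kb > max_k:
            raise ValueError("per-batch k must not exceed max_k")
        v, i = _select(s, kb, largest, sorted_out)
        out_vals.append(v)
        out_idxs.append(i)
    return out_vals, out_idxs
```

For example, top_k_ref([[3, 1, 4, 1], [5, 9, 2, 6]], max_k=2, axis=1) returns ([[4, 3], [9, 6]], [[2, 0], [1, 3]]): the two largest values of each row, best first, with their column indices.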
