Mojo function

topk_gpu

topk_gpu[dtype: DType, out_idx_type: DType, //, sampling: Bool = True, largest: Bool = True, _force_old_impl: Bool = False, KLayoutType: TensorLayout = Layout[*?, *?], TemperatureLayoutType: TensorLayout = Layout[*?, *?], TopPLayoutType: TensorLayout = Layout[*?, *?], MinPLayoutType: TensorLayout = Layout[*?, *?], SeedLayoutType: TensorLayout = Layout[*?, *?]](ctx: DeviceContext, max_k: Int, input: TileTensor[dtype, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], out_vals: TileTensor[dtype, address_space=out_vals.address_space, linear_idx_type=out_vals.linear_idx_type, element_size=out_vals.element_size], out_idxs: TileTensor[out_idx_type, address_space=out_idxs.address_space, linear_idx_type=out_idxs.linear_idx_type, element_size=out_idxs.element_size], block_size: Optional[Int] = None, num_blocks_per_input: Optional[Int] = None, k: Optional[TileTensor[DType.int64, KLayoutType, ImmutAnyOrigin]] = None, temperature: Optional[TileTensor[DType.float32, TemperatureLayoutType, ImmutAnyOrigin]] = None, top_p: Optional[TileTensor[DType.float32, TopPLayoutType, ImmutAnyOrigin]] = None, min_p: Optional[TileTensor[DType.float32, MinPLayoutType, ImmutAnyOrigin]] = None, seed: Optional[TileTensor[DType.uint64, SeedLayoutType, ImmutAnyOrigin]] = None)

Generalized GPU implementation of the top-K algorithm, with or without sampling. With sampling enabled, it returns the sampled index from the innermost dimension of the input tensor for each row/subvolume; with sampling disabled, it returns the top K values and their indices.
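To make the two modes concrete, here is a minimal CPU reference sketch in plain Python. It is not the GPU kernel and not part of the Mojo API: the function names `topk_reference` and `topk_sample_reference` are hypothetical, and the sampling rule (a temperature-scaled softmax over each row's top-K logits) is an assumption about how the sampled index is drawn, since the description above does not spell it out.

```python
import math
import random


def topk_reference(rows, k, largest=True):
    """Return (values, indices) of the k largest (or smallest) entries per row."""
    out_vals, out_idxs = [], []
    for row in rows:
        # Order the indices by value; the sort is stable, so ties keep the lower index first.
        order = sorted(range(len(row)), key=lambda i: row[i], reverse=largest)[:k]
        out_idxs.append(order)
        out_vals.append([row[i] for i in order])
    return out_vals, out_idxs


def topk_sample_reference(rows, k, temperature=1.0, seed=0):
    """Sample one index per row from a softmax over that row's top-k logits.

    Assumption: this sampling rule is illustrative only.
    """
    rng = random.Random(seed)
    vals, idxs = topk_reference(rows, k, largest=True)
    sampled = []
    for row_vals, row_idxs in zip(vals, idxs):
        m = row_vals[0]  # largest logit, subtracted for numerical stability
        weights = [math.exp((v - m) / temperature) for v in row_vals]
        sampled.append(rng.choices(row_idxs, weights=weights, k=1)[0])
    return sampled


logits = [[0.1, 2.0, -1.0, 3.5], [4.0, 0.0, 4.5, 1.0]]

# Without sampling: top-2 values and indices for each row.
vals, idxs = topk_reference(logits, 2)
print(vals)  # [[3.5, 2.0], [4.5, 4.0]]
print(idxs)  # [[3, 1], [2, 0]]

# With sampling: one index per row, always drawn from that row's top-2 set.
print(topk_sample_reference(logits, 2, temperature=0.8, seed=42))
```

In this sketch the non-sampling call corresponds to `sampling=False`, and passing `largest=False` selects the smallest entries instead.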

Parameters:

  • dtype (DType): The data type of the input tensor.
  • out_idx_type (DType): The data type of the output indices (default: DType.int).
  • sampling (Bool): Whether to return token samples drawn from the top-K distribution (default: True).
  • largest (Bool): Whether to select the largest (True) or smallest (False) values (default: True).
  • _force_old_impl (Bool): Whether to force use of the old implementation (default: False).
  • KLayoutType (TensorLayout): Layout type of the k buffer.
  • TemperatureLayoutType (TensorLayout): Layout type of the temperature buffer.
  • TopPLayoutType (TensorLayout): Layout type of the top_p buffer.
  • MinPLayoutType (TensorLayout): Layout type of the min_p buffer.
  • SeedLayoutType (TensorLayout): Layout type of the seed buffer.

Args: