Mojo function
fused_token_sampling_cpu
fused_token_sampling_cpu[dtype: DType, out_idx_type: DType, KLayoutType: TensorLayout = Layout[*?, *?], TemperatureLayoutType: TensorLayout = Layout[*?, *?], TopPLayoutType: TensorLayout = Layout[*?, *?], SeedLayoutType: TensorLayout = Layout[*?, *?]](max_k: Int, input: TileTensor[dtype, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], out_idxs: TileTensor[out_idx_type, address_space=out_idxs.address_space, linear_idx_type=out_idxs.linear_idx_type, element_size=out_idxs.element_size], k: Optional[TileTensor[DType.int64, KLayoutType, ImmutAnyOrigin]] = None, temperature: Optional[TileTensor[DType.float32, TemperatureLayoutType, ImmutAnyOrigin]] = None, top_p: Optional[TileTensor[DType.float32, TopPLayoutType, ImmutAnyOrigin]] = None, seed: Optional[TileTensor[DType.uint64, SeedLayoutType, ImmutAnyOrigin]] = None)
Generalized implementation of the Top K algorithm with sampling. Returns the sampled index from the innermost dimension of the input tensor for each row/subvolume.
Parameters:
- dtype (DType): Data type of the input buffer.
- out_idx_type (DType): Data type of the output indices.
- KLayoutType (TensorLayout): Layout type of the k buffer.
- TemperatureLayoutType (TensorLayout): Layout type of the temperature buffer.
- TopPLayoutType (TensorLayout): Layout type of the top_p buffer.
- SeedLayoutType (TensorLayout): Layout type of the seed buffer.
Args:
- max_k (Int): Largest number of top elements.
- input (TileTensor[dtype, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size]): The input tensor, of any shape.
- out_idxs (TileTensor[out_idx_type, address_space=out_idxs.address_space, linear_idx_type=out_idxs.linear_idx_type, element_size=out_idxs.element_size]): The output indices, of shape [input_shape[:-1]] + [1].
- k (Optional[TileTensor[DType.int64, KLayoutType, ImmutAnyOrigin]]): Optional device buffer holding the number of top elements to keep for each batch element.
- temperature (Optional[TileTensor[DType.float32, TemperatureLayoutType, ImmutAnyOrigin]]): The temperature-based scaling.
- top_p (Optional[TileTensor[DType.float32, TopPLayoutType, ImmutAnyOrigin]]): Top-p (nucleus) threshold. Only the top tokens whose cumulative probability exceeds this threshold are used for sampling.
- seed (Optional[TileTensor[DType.uint64, SeedLayoutType, ImmutAnyOrigin]]): The seed to use for the random number generator.
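To make the roles of k, temperature, top_p, and seed more concrete, here is a minimal reference sketch in plain Python. It is not the Mojo kernel and does not use TileTensor. The helper names (sample_row, sample_batch) are hypothetical, the order of operations (top-k, then temperature scaling, then top-p) is an assumption rather than something this page states, and temperature and top_p are simplified to scalars shared by every row.

```python
import math
import random


def sample_row(logits, k=None, temperature=1.0, top_p=1.0, rng=None):
    """Sample one index from a single row of logits (illustrative only).

    Assumes temperature > 0 and 0 < top_p <= 1.
    """
    rng = rng or random.Random()
    n = len(logits)
    k = n if k is None else max(1, min(k, n))

    # Top-k: keep the indices of the k largest logits, largest first.
    top = sorted(range(n), key=lambda i: logits[i], reverse=True)[:k]

    # Temperature scaling followed by a numerically stable softmax.
    scaled = [logits[i] / temperature for i in top]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Top-p: keep the shortest prefix whose cumulative probability reaches top_p.
    kept, cumulative = 0, 0.0
    for p in probs:
        kept += 1
        cumulative += p
        if cumulative >= top_p:
            break

    # Draw from the kept tokens in proportion to their probabilities.
    return rng.choices(top[:kept], weights=probs[:kept], k=1)[0]


def sample_batch(rows, ks=None, temperature=1.0, top_p=1.0, seed=0):
    """Return one [index] per row, mirroring the documented [..., 1] output shape."""
    rng = random.Random(seed)  # assumption: one seeded generator shared by all rows
    out = []
    for b, row in enumerate(rows):
        k = None if ks is None else ks[b]
        out.append([sample_row(row, k, temperature, top_p, rng)])
    return out
```

With k set to 1 for every row, the sketch reduces to a per-row argmax, for example sample_batch([[0.1, 2.0, -1.0], [3.0, 0.0, 1.0]], ks=[1, 1]). Larger k, higher temperature, or a top_p closer to 1 widen the set of indices that can be drawn, and reusing the same seed reproduces the same draws.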