
Mojo function

TopKSoftmaxSampleKernel

TopKSoftmaxSampleKernel[block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, logits_origin: ImmutOrigin, logits_shape_types: Variadic[CoordLike], logits_stride_types: Variadic[CoordLike], sampled_origin: MutOrigin, sampled_shape_types: Variadic[CoordLike], sampled_stride_types: Variadic[CoordLike]](logits: TileTensor[dtype, logits_origin], sampled_indices: TileTensor[out_idx_type, sampled_origin], top_k_arr: UnsafePointer[Scalar[out_idx_type], MutExternalOrigin], top_k_val: Int, temperature_val: Float32, temperature: UnsafePointer[Float32, MutExternalOrigin], seed_val: UInt64, seed: UnsafePointer[UInt64, MutExternalOrigin], d: Int)
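This page lists only the signature. The name suggests the kernel does three things for each row of `logits`: it scales the logits by a temperature, takes a softmax over the top-k entries, and samples one index into `sampled_indices`. The plain-Python CPU sketch below illustrates that reading. It is not the kernel's implementation, and several points are assumptions this page does not confirm:

- `logits` is treated as a 2-D `[batch, d]` tensor that yields one sampled index per row.
- When the `top_k_arr`, `temperature`, and `seed` pointers are non-null, they are assumed to supply per-row values that override `top_k_val`, `temperature_val`, and `seed_val`.
- The temperature is assumed to be positive.

The function name `top_k_softmax_sample` is hypothetical. Its RNG is Python's `random.Random`, so it will not reproduce the kernel's random stream. `block_size` and `vec_size` presumably control GPU launch width and vectorization, and they are not modeled here.

```python
import math
import random
from typing import List, Optional, Sequence


def top_k_softmax_sample(
    logits: Sequence[Sequence[float]],
    top_k_val: int,
    temperature_val: float,
    seed_val: int,
    top_k_arr: Optional[Sequence[int]] = None,
    temperature: Optional[Sequence[float]] = None,
    seed: Optional[Sequence[int]] = None,
) -> List[int]:
    """Hypothetical CPU reference: sample one index per row of `logits`."""
    sampled = []
    for row_idx, row in enumerate(logits):
        # ASSUMPTION: per-row arrays, when given, override the scalar fallbacks.
        k = top_k_arr[row_idx] if top_k_arr is not None else top_k_val
        temp = temperature[row_idx] if temperature is not None else temperature_val
        row_seed = seed[row_idx] if seed is not None else seed_val
        if temp <= 0.0:
            # ASSUMPTION: this sketch only handles positive temperatures.
            raise ValueError("temperature must be positive in this sketch")

        d = len(row)
        k = max(1, min(k, d))  # clamp k to the range [1, d]

        # Indices of the k largest logits, ties broken by the lower index.
        top = sorted(range(d), key=lambda i: (-row[i], i))[:k]

        # Temperature-scaled softmax over the top-k, shifted by the max for stability.
        scaled = [row[i] / temp for i in top]
        m = max(scaled)
        weights = [math.exp(s - m) for s in scaled]
        total = sum(weights)

        # Inverse-CDF sampling. Mixing the row index into the seed keeps rows
        # that share one scalar seed from drawing the same uniform value.
        u = random.Random(row_seed + row_idx).random() * total
        acc = 0.0
        choice = top[-1]  # fallback in case rounding leaves acc slightly below total
        for idx, w in zip(top, weights):
            acc += w
            if u < acc:
                choice = idx
                break
        sampled.append(choice)
    return sampled


if __name__ == "__main__":
    logits = [[1.0, 3.0, 2.0, -1.0], [0.5, 0.1, 4.0, 0.2]]
    # With top_k_val=1, sampling degenerates to a per-row argmax: prints [1, 2].
    print(top_k_softmax_sample(logits, top_k_val=1, temperature_val=1.0, seed_val=0))
```

Setting `top_k_val` to 1 is a convenient sanity check, because every row then returns its argmax no matter which seed is used. With larger `k`, each sampled index always comes from that row's `k` largest logits, and lowering the temperature concentrates probability on the largest of them.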
