Mojo function

TopKMaskLogitsKernel

TopKMaskLogitsKernel[
    block_size: Int,
    vec_size: Int,
    dtype: DType,
    out_idx_type: DType,
    logits_layout: Layout,
    masked_logits_layout: Layout
](
    logits: LayoutTensor[dtype, logits_layout, MutableAnyOrigin],
    masked_logits: LayoutTensor[dtype, masked_logits_layout, MutableAnyOrigin],
    top_k_arr: UnsafePointer[Scalar[out_idx_type]],
    top_k_val: Int,
    d: Int
)

Was this page helpful?