IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

apply_penalties_to_logits

def apply_penalties_to_logits[logit_type: DType, penalty_type: DType, //, target: StringSlice[StaticConstantOrigin]](logits: TileTensor[logit_type, address_space=logits.address_space, linear_idx_type=logits.linear_idx_type, element_size=logits.element_size], compressed_frequency_data: TileTensor[DType.int32, address_space=compressed_frequency_data.address_space, linear_idx_type=compressed_frequency_data.linear_idx_type, element_size=compressed_frequency_data.element_size], frequency_offsets: TileTensor[DType.uint32, address_space=frequency_offsets.address_space, linear_idx_type=frequency_offsets.linear_idx_type, element_size=frequency_offsets.element_size], frequency_penalty: TileTensor[penalty_type, address_space=frequency_penalty.address_space, linear_idx_type=frequency_penalty.linear_idx_type, element_size=frequency_penalty.element_size], presence_penalty: TileTensor[penalty_type, address_space=presence_penalty.address_space, linear_idx_type=presence_penalty.linear_idx_type, element_size=presence_penalty.element_size], repetition_penalty: TileTensor[penalty_type, address_space=repetition_penalty.address_space, linear_idx_type=repetition_penalty.linear_idx_type, element_size=repetition_penalty.element_size], ctx: DeviceContext)

Apply penalties to the logits based on the frequency of the tokens in the batch.

The frequency data is stored in CSR (compressed sparse row) format: frequency_offsets holds the starting index of each sequence in the frequency_data array, which is passed as the compressed_frequency_data argument. The frequency_data array is a 2D array, where:

  • frequency_data[i, 0] is the token id
  • frequency_data[i, 1] is the frequency of the token in the sequence
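
The layout above can be illustrated with a small host-side sketch in plain Python. This is not the Mojo kernel and does not call it. It only shows how the rows for one sequence are located through frequency_offsets. Several things here are assumptions not stated on this page: that frequency_offsets has one extra trailing entry marking the end of the last sequence (the usual CSR convention), that logits is indexed as [sequence][token id], that each penalty tensor holds one value per sequence, and the penalty formulas themselves (the commonly used subtract-and-scale scheme). The helper names rows_for_sequence and apply_penalties_reference are hypothetical.

```python
def rows_for_sequence(frequency_data, frequency_offsets, seq):
    """Return the (token id, frequency) rows that belong to sequence `seq`.

    Assumes frequency_offsets has len(batch) + 1 entries, so the rows for
    sequence `seq` are frequency_data[offsets[seq]:offsets[seq + 1]].
    """
    start, end = frequency_offsets[seq], frequency_offsets[seq + 1]
    return frequency_data[start:end]


def apply_penalties_reference(logits, frequency_data, frequency_offsets,
                              frequency_penalty, presence_penalty,
                              repetition_penalty):
    """Reference sketch only; the formulas below are assumed, not documented."""
    out = [row[:] for row in logits]  # leave the caller's logits untouched
    for seq in range(len(out)):
        rows = rows_for_sequence(frequency_data, frequency_offsets, seq)
        for token_id, count in rows:
            if count <= 0:
                continue  # token has not appeared, nothing to penalize
            value = out[seq][token_id]
            # Assumed repetition penalty: shrink positive logits and
            # push negative logits further down.
            rep = repetition_penalty[seq]
            value = value / rep if value > 0 else value * rep
            # Assumed frequency penalty (scales with the count) and
            # presence penalty (flat, applied once per seen token).
            value -= frequency_penalty[seq] * count + presence_penalty[seq]
            out[seq][token_id] = value
    return out


# Two sequences over a 4-token vocabulary.
logits = [[2.0, 1.0, -1.0, 0.5],
          [0.0, 4.0, 1.0, 3.0]]
# Sequence 0 saw token 0 twice and token 2 once; sequence 1 saw token 1 three times.
frequency_data = [[0, 2], [2, 1], [1, 3]]
frequency_offsets = [0, 2, 3]

penalized = apply_penalties_reference(
    logits, frequency_data, frequency_offsets,
    frequency_penalty=[0.5, 0.25],
    presence_penalty=[1.0, 0.0],
    repetition_penalty=[2.0, 1.0],
)
```

Only the tokens listed for a sequence are touched, which is the point of the compressed layout: the work scales with the number of distinct tokens seen rather than with the vocabulary size.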