IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

causal_conv1d_varlen_states_cpu

causal_conv1d_varlen_states_cpu[x_dtype: DType, cu_seqlens_dtype: DType, states_dtype: DType](total_tokens: Int, dim: Int, batch: Int, state_len: Int, x: TileTensor[x_dtype, address_space=x.address_space, linear_idx_type=x.linear_idx_type, element_size=x.element_size], cu_seqlens: TileTensor[cu_seqlens_dtype, address_space=cu_seqlens.address_space, linear_idx_type=cu_seqlens.linear_idx_type, element_size=cu_seqlens.element_size], states: TileTensor[states_dtype, address_space=states.address_space, linear_idx_type=states.linear_idx_type, element_size=states.element_size], x_seqlen_stride: UInt32, x_dim_stride: UInt32, states_batch_stride: UInt32, states_dim_stride: UInt32, states_seqlen_stride: UInt32)

Extracts the last state_len tokens from each variable-length sequence.

For each sequence in the batch, copies its last state_len tokens into the states tensor. If the sequence has fewer than state_len tokens, all of its tokens are copied into the trailing positions of its slot in states, and the leading positions are filled with zeros.

This is the CPU reference implementation for causal_conv1d_varlen_states.
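The copy-and-pad rule above can be modeled in plain Python. This is a minimal sketch, not the Mojo implementation, and it rests on several assumptions: `varlen_states_reference` is a hypothetical helper; x is a flat buffer of shape (total_tokens, dim) addressed with x_seqlen_stride and x_dim_stride; states is a flat buffer of shape (batch, dim, state_len) addressed with the three states strides; cu_seqlens holds batch + 1 offsets starting at 0. The dtype parameters and total_tokens are not modeled.

```python
def varlen_states_reference(x, cu_seqlens, states, batch, dim, state_len,
                            x_seqlen_stride, x_dim_stride,
                            states_batch_stride, states_dim_stride,
                            states_seqlen_stride):
    """Hypothetical plain-Python model; writes into `states` in place."""
    for b in range(batch):
        # Assumed: sequence b occupies tokens [cu_seqlens[b], cu_seqlens[b + 1]).
        start, end = cu_seqlens[b], cu_seqlens[b + 1]
        # Number of real tokens to copy; shorter sequences copy all of them.
        n = min(state_len, end - start)
        pad = state_len - n  # leading slots that get zero-filled
        for d in range(dim):
            for s in range(state_len):
                out = (b * states_batch_stride + d * states_dim_stride
                       + s * states_seqlen_stride)
                if s < pad:
                    states[out] = 0.0
                else:
                    # Slot s holds token (end - state_len + s), so the last
                    # slot always holds the sequence's final token.
                    t = end - state_len + s
                    states[out] = x[t * x_seqlen_stride + d * x_dim_stride]


# Two sequences (lengths 3 and 1), dim=2, state_len=2, contiguous layouts.
x = [1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0]  # x[t, d] at t*2 + d
cu_seqlens = [0, 3, 4]
states = [-1.0] * 8  # states[b, d, s] at b*4 + d*2 + s
varlen_states_reference(x, cu_seqlens, states, batch=2, dim=2, state_len=2,
                        x_seqlen_stride=2, x_dim_stride=1,
                        states_batch_stride=4, states_dim_stride=2,
                        states_seqlen_stride=1)
print(states)  # [2.0, 3.0, 20.0, 30.0, 0.0, 4.0, 0.0, 40.0]
```

The first sequence is long enough, so its slot receives tokens 1 and 2. The second sequence has a single token, so its slot receives one zero followed by token 3. Because every index is computed from explicit strides, the same loop handles non-contiguous layouts, which is presumably why the function takes the strides as separate arguments.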

Parameters:

  • x_dtype (DType): Data type of the input tensor.
  • cu_seqlens_dtype (DType): Data type of the cumulative sequence lengths.
  • states_dtype (DType): Data type of the output states tensor.

Args: