Mojo function
causal_conv1d_varlen_states_cpu
causal_conv1d_varlen_states_cpu[x_dtype: DType, cu_seqlens_dtype: DType, states_dtype: DType](total_tokens: Int, dim: Int, batch: Int, state_len: Int, x: TileTensor[x_dtype, address_space=x.address_space, linear_idx_type=x.linear_idx_type, element_size=x.element_size], cu_seqlens: TileTensor[cu_seqlens_dtype, address_space=cu_seqlens.address_space, linear_idx_type=cu_seqlens.linear_idx_type, element_size=cu_seqlens.element_size], states: TileTensor[states_dtype, address_space=states.address_space, linear_idx_type=states.linear_idx_type, element_size=states.element_size], x_seqlen_stride: UInt32, x_dim_stride: UInt32, states_batch_stride: UInt32, states_dim_stride: UInt32, states_seqlen_stride: UInt32)
Extract the last state_len elements from each variable-length sequence.
For each sequence in the batch, copies the last state_len tokens (or fewer if the sequence is shorter) to the states tensor. If a sequence is shorter than state_len, the earlier positions in states are zero-padded.
This is the CPU reference implementation for causal_conv1d_varlen_states.
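The copy-and-pad behavior described above can be illustrated with a small pure-Python sketch. This is not the Mojo implementation: the helper name `varlen_states_reference` and the nested-list layout are assumptions made for illustration, and the right-aligned placement is inferred from the statement that the earlier positions in states are zero-padded.

```python
def varlen_states_reference(x, cu_seqlens, state_len):
    """Illustrative sketch only (hypothetical helper, not the Mojo API).

    x: list of total_tokens rows, each a list of dim values.
    cu_seqlens: list of batch + 1 cumulative offsets into x.
    Returns states as nested lists of shape (batch, dim, state_len).
    """
    batch = len(cu_seqlens) - 1
    dim = len(x[0]) if x else 0
    # Start from all zeros so short sequences end up zero-padded.
    states = [[[0.0] * state_len for _ in range(dim)] for _ in range(batch)]
    for b in range(batch):
        start, end = cu_seqlens[b], cu_seqlens[b + 1]
        n = min(end - start, state_len)  # number of tokens actually copied
        for i in range(n):
            src = end - n + i        # one of the last n tokens of sequence b
            dst = state_len - n + i  # right-aligned slot in the state window
            for d in range(dim):
                states[b][d][dst] = x[src][d]
    return states


# Two sequences packed back to back: lengths 3 and 1, with dim = 2.
x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
cu_seqlens = [0, 3, 4]
states = varlen_states_reference(x, cu_seqlens, state_len=2)
print(states)
```

In this example the first sequence is longer than state_len, so only its last two tokens are kept. The second sequence has a single token, so under the assumed layout its first state position stays zero.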
Parameters:
- x_dtype (DType): Data type of the input tensor.
- cu_seqlens_dtype (DType): Data type of the cumulative sequence lengths.
- states_dtype (DType): Data type of the output states tensor.
Args:
- total_tokens (Int): Total number of tokens across all sequences.
- dim (Int): Number of channels/dimensions.
- batch (Int): Number of sequences.
- state_len (Int): Number of elements to extract per sequence (typically width - 1).
- x (TileTensor[x_dtype, address_space=x.address_space, linear_idx_type=x.linear_idx_type, element_size=x.element_size]): Input tensor of shape (total_tokens, dim).
- cu_seqlens (TileTensor[cu_seqlens_dtype, address_space=cu_seqlens.address_space, linear_idx_type=cu_seqlens.linear_idx_type, element_size=cu_seqlens.element_size]): Cumulative sequence lengths of shape (batch + 1,).
- states (TileTensor[states_dtype, address_space=states.address_space, linear_idx_type=states.linear_idx_type, element_size=states.element_size]): Output states tensor of shape (batch, dim, state_len).
- x_seqlen_stride (UInt32): Stride for sequence dimension in x.
- x_dim_stride (UInt32): Stride for dimension in x.
- states_batch_stride (UInt32): Stride for batch dimension in states.
- states_dim_stride (UInt32): Stride for dimension in states.
- states_seqlen_stride (UInt32): Stride for sequence dimension in states.
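The stride arguments suggest that elements of x and states are addressed through flat offsets. The Python sketch below shows one plausible reading. It assumes that strides are counted in elements (not bytes) and that an offset is the sum of index-times-stride terms. The helper names `x_offset` and `states_offset` are hypothetical and are not part of the Mojo API.

```python
def x_offset(token, channel, x_seqlen_stride, x_dim_stride):
    # Assumed flat offset of x[token, channel], with strides in elements.
    return token * x_seqlen_stride + channel * x_dim_stride


def states_offset(b, channel, pos, states_batch_stride,
                  states_dim_stride, states_seqlen_stride):
    # Assumed flat offset of states[b, channel, pos], with strides in elements.
    return (b * states_batch_stride
            + channel * states_dim_stride
            + pos * states_seqlen_stride)


# Strides for contiguous row-major tensors with dim = 4 and state_len = 3:
# x has shape (total_tokens, dim), states has shape (batch, dim, state_len).
dim, state_len = 4, 3
x_strides = (dim, 1)
states_strides = (dim * state_len, state_len, 1)

x_off = x_offset(2, 1, *x_strides)                # token 2, channel 1
s_off = states_offset(1, 2, 1, *states_strides)   # batch 1, channel 2, position 1
print(x_off, s_off)
```

Passing the strides explicitly, rather than deriving them from the shapes, would let the same routine read non-contiguous layouts such as a transposed view of x.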