IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

causal_conv1d_varlen_states_gpu

causal_conv1d_varlen_states_gpu[
    x_dtype: DType,
    cu_seqlens_dtype: DType,
    states_dtype: DType,
    BLOCK_M: Int,
    BLOCK_N: Int,
    x_LT: TensorLayout,
    cu_seqlens_LT: TensorLayout,
    states_LT: TensorLayout
](
    total_tokens: Int,
    dim: Int,
    batch: Int,
    state_len: Int,
    x: TileTensor[x_dtype, x_LT, MutExternalOrigin],
    cu_seqlens: TileTensor[cu_seqlens_dtype, cu_seqlens_LT, MutExternalOrigin],
    states: TileTensor[states_dtype, states_LT, MutExternalOrigin],
    x_seqlen_stride: UInt32,
    x_dim_stride: UInt32,
    states_batch_stride: UInt32,
    states_dim_stride: UInt32,
    states_seqlen_stride: UInt32
)

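GPU kernel that extracts causal conv1d states from a batch of variable-length sequences packed into a single input tensor, with sequence boundaries given by cumulative sequence lengths.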
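
This page does not spell out the extraction rule, so the following is only a CPU reference sketch in Python under stated assumptions: `x` is packed as `total_tokens` rows of `dim` channels, `cu_seqlens` holds `batch + 1` cumulative offsets, and each sequence contributes its last `state_len` tokens, right-aligned and zero-padded on the left when the sequence is shorter. That rule is borrowed from the common causal conv1d varlen convention, not confirmed by this page. The function name `varlen_states_ref` and the `(batch, dim, state_len)` output nesting are hypothetical; the real kernel addresses `states` through the stride arguments.

```python
def varlen_states_ref(x, cu_seqlens, state_len):
    """Reference sketch (assumed semantics, not the kernel itself).

    x:          list of total_tokens rows, each a list of dim values
    cu_seqlens: list of batch + 1 cumulative sequence offsets
    returns:    nested list indexed as states[b][d][s]
    """
    dim = len(x[0]) if x else 0
    batch = len(cu_seqlens) - 1
    # Start from zeros so short sequences stay zero-padded on the left.
    states = [[[0] * state_len for _ in range(dim)] for _ in range(batch)]
    for b in range(batch):
        end = cu_seqlens[b + 1]
        # Never read tokens that belong to the previous sequence.
        start = max(cu_seqlens[b], end - state_len)
        n = end - start
        for i in range(n):
            s = state_len - n + i  # right-align the copied tokens
            for d in range(dim):
                states[b][d][s] = x[start + i][d]
    return states


# Two sequences of lengths 1 and 4, packed into 5 tokens with dim = 2.
x = [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50]]
states = varlen_states_ref(x, cu_seqlens=[0, 1, 5], state_len=3)
print(states[0])  # [[0, 0, 1], [0, 0, 10]]
print(states[1])  # [[3, 4, 5], [30, 40, 50]]
```

The first sequence has only one token, so two of its three state slots stay zero; the second sequence keeps its last three tokens.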

Each thread block processes one tile of BLOCK_M x BLOCK_N elements: BLOCK_M positions along the sequence (state_len) dimension by BLOCK_N channels along dim. Grid dimensions: (ceildiv(dim, BLOCK_N), ceildiv(state_len, BLOCK_M), batch).
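
As a rough illustration of the grid arithmetic above, here is a host-side Python sketch. The helper names `launch_grid` and `tile_bounds` are hypothetical, and the mapping of block indices to tiles (x to channel tiles, y to sequence tiles, z to batch entries) is an assumption inferred from the order of the grid dimensions, not something this page states.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for non-negative a and positive b."""
    return -(-a // b)


def launch_grid(dim, state_len, batch, block_m, block_n):
    """Grid shape as documented: (channel tiles, sequence tiles, batch)."""
    return (ceildiv(dim, block_n), ceildiv(state_len, block_m), batch)


def tile_bounds(block_idx, dim, state_len, block_m, block_n):
    """Half-open ranges covered by one block (assumed index mapping).

    Edge tiles are clipped so that a block never covers positions
    past state_len or channels past dim.
    """
    bx, by, bz = block_idx
    col_start = bx * block_n
    col_end = min(col_start + block_n, dim)
    row_start = by * block_m
    row_end = min(row_start + block_m, state_len)
    return bz, (row_start, row_end), (col_start, col_end)


grid = launch_grid(dim=768, state_len=3, batch=4, block_m=8, block_n=256)
print(grid)  # (3, 1, 4)

# The last channel tile of batch entry 1 covers channels 512..767.
print(tile_bounds((2, 0, 1), dim=768, state_len=3, block_m=8, block_n=256))
# (1, (0, 3), (512, 768))
```

Because `state_len` is usually small, a single tile along the sequence axis often suffices, which leaves most of the parallelism in the channel and batch dimensions.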

Parameters:

  • x_dtype (DType): Data type of the input tensor x.
  • cu_seqlens_dtype (DType): Data type of the cumulative sequence lengths tensor cu_seqlens.
  • states_dtype (DType): Data type of the output states tensor.
  • BLOCK_M (Int): Tile size along the sequence (state_len) dimension.
  • BLOCK_N (Int): Tile size along the channel (dim) dimension.
  • x_LT (TensorLayout): Layout type of the input tensor.
  • cu_seqlens_LT (TensorLayout): Layout type of the cumulative sequence lengths tensor.
  • states_LT (TensorLayout): Layout type of the output states tensor.

Args: