Mojo function
causal_conv1d_varlen_states_gpu
causal_conv1d_varlen_states_gpu[x_dtype: DType, cu_seqlens_dtype: DType, states_dtype: DType, BLOCK_M: Int, BLOCK_N: Int, x_LT: TensorLayout, cu_seqlens_LT: TensorLayout, states_LT: TensorLayout](total_tokens: Int, dim: Int, batch: Int, state_len: Int, x: TileTensor[x_dtype, x_LT, MutExternalOrigin], cu_seqlens: TileTensor[cu_seqlens_dtype, cu_seqlens_LT, MutExternalOrigin], states: TileTensor[states_dtype, states_LT, MutExternalOrigin], x_seqlen_stride: UInt32, x_dim_stride: UInt32, states_batch_stride: UInt32, states_dim_stride: UInt32, states_seqlen_stride: UInt32)
GPU kernel for extracting states from variable length sequences.
Each thread block processes a tile of (BLOCK_M x BLOCK_N) elements. Grid dimensions: (ceildiv(dim, BLOCK_N), ceildiv(state_len, BLOCK_M), batch)
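The grid shape above is a ceiling division along the channel and state axes, with one layer of blocks per sequence. The following is a minimal host-side illustration in Python rather than Mojo; `launch_grid` is a hypothetical helper name used only for this sketch and is not part of the documented API.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for positive b, matching the ceildiv in the grid formula."""
    return -(-a // b)


def launch_grid(dim: int, state_len: int, batch: int, block_m: int, block_n: int):
    """Return the (x, y, z) grid dimensions described in the text.

    x covers channels in tiles of block_n (BLOCK_N),
    y covers state positions in tiles of block_m (BLOCK_M),
    z has one entry per sequence in the batch.
    """
    return (ceildiv(dim, block_n), ceildiv(state_len, block_m), batch)


# Example: 1000 channels, state_len 3, 4 sequences, BLOCK_M=8, BLOCK_N=256.
print(launch_grid(dim=1000, state_len=3, batch=4, block_m=8, block_n=256))
```

When `dim` or `state_len` is not a multiple of the tile size, the last tile along that axis is only partly filled, so a kernel launched this way presumably masks out-of-range elements.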
Parameters:
- x_dtype (DType): Data type of input.
- cu_seqlens_dtype (DType): Data type of cumulative sequence lengths.
- states_dtype (DType): Data type of output states.
- BLOCK_M (Int): Tile size for sequence dimension.
- BLOCK_N (Int): Tile size for channel dimension.
- x_LT (TensorLayout): Layout type of input tensor.
- cu_seqlens_LT (TensorLayout): Layout type of cumulative sequence lengths tensor.
- states_LT (TensorLayout): Layout type of output states tensor.
Args:
- total_tokens (Int): Total number of tokens.
- dim (Int): Number of channels.
- batch (Int): Number of sequences.
- state_len (Int): State length to extract.
- x (TileTensor[x_dtype, x_LT, MutExternalOrigin]): Input tensor.
- cu_seqlens (TileTensor[cu_seqlens_dtype, cu_seqlens_LT, MutExternalOrigin]): Cumulative sequence lengths.
- states (TileTensor[states_dtype, states_LT, MutExternalOrigin]): Output states tensor.
- x_seqlen_stride (UInt32): Stride for sequence in x.
- x_dim_stride (UInt32): Stride for dimension in x.
- states_batch_stride (UInt32): Stride for batch in states.
- states_dim_stride (UInt32): Stride for dimension in states.
- states_seqlen_stride (UInt32): Stride for sequence in states.
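To show how the arguments above could fit together, here is a sequential Python sketch of the extraction, not the Mojo kernel itself. It rests on several assumptions this page does not confirm: `cu_seqlens` holds `batch + 1` offsets, each sequence's state is its last `state_len` tokens, sequences shorter than `state_len` are right-aligned with zero padding on the left (the convention of the reference causal-conv1d implementation), and strides are counted in elements. `extract_states_reference` is a hypothetical name.

```python
def extract_states_reference(
    x, cu_seqlens, states, dim, batch, state_len,
    x_seqlen_stride, x_dim_stride,
    states_batch_stride, states_dim_stride, states_seqlen_stride,
):
    """Fill the flat `states` buffer in place from the flat `x` buffer.

    Assumed semantics (not confirmed by the API page): keep the last
    `state_len` tokens of every sequence, zero-padded on the left.
    """
    for b in range(batch):
        seq_start = cu_seqlens[b]      # assumed: offset of first token
        seq_end = cu_seqlens[b + 1]    # assumed: one past the last token
        n_valid = min(state_len, seq_end - seq_start)
        pad = state_len - n_valid      # leading positions with no token
        for s in range(state_len):
            for d in range(dim):
                out = (b * states_batch_stride
                       + d * states_dim_stride
                       + s * states_seqlen_stride)
                if s < pad:
                    states[out] = 0
                else:
                    token = seq_end - state_len + s
                    states[out] = x[token * x_seqlen_stride + d * x_dim_stride]
    return states


# Two sequences of lengths 2 and 4, dim=2, state_len=3.
# x is row-major (total_tokens, dim); states is row-major (batch, dim, state_len).
x = [10 * t + d + 1 for t in range(6) for d in range(2)]
states = extract_states_reference(
    x, [0, 2, 6], [None] * 12, dim=2, batch=2, state_len=3,
    x_seqlen_stride=2, x_dim_stride=1,
    states_batch_stride=6, states_dim_stride=3, states_seqlen_stride=1,
)
print(states)
```

In the GPU version, the two inner loops would correspond to one (BLOCK_M x BLOCK_N) tile per thread block and the outer loop to the batch axis of the grid.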