
Mojo struct

SplitParallelCombineParams

struct SplitParallelCombineParams[output_type: DType, accum_type: DType, num_splits: Int, ragged: Bool = False, has_attn_sink: Bool = False]
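The ragged parameter works together with the input_row_offsets_ptr field. The plain-Python sketch below shows how a flattened query row could map back to its sequence. It assumes the common ragged-tensor convention: batch_size + 1 offsets, with sequence b owning rows offsets[b] up to, but not including, offsets[b + 1]. The helper names row_to_batch_pos and padded_row are hypothetical and are not part of this API.

```python
from bisect import bisect_right
from typing import List, Tuple


def row_to_batch_pos(input_row_offsets: List[int], row: int) -> Tuple[int, int]:
    """Map a flattened query row to (batch index, position in that sequence).

    Assumed ragged convention: len(input_row_offsets) == batch_size + 1 and
    sequence b owns rows [input_row_offsets[b], input_row_offsets[b + 1]).
    """
    if not 0 <= row < input_row_offsets[-1]:
        raise IndexError("row out of range")
    # The first offset strictly greater than row, minus one, is the owning
    # sequence; this also skips empty sequences (repeated offsets).
    batch = bisect_right(input_row_offsets, row) - 1
    return batch, row - input_row_offsets[batch]


def padded_row(batch: int, pos: int, seq_len: int) -> int:
    """Non-ragged case (assumed): every sequence occupies exactly seq_len rows."""
    return batch * seq_len + pos


# Sequences of length 2, 0 and 3 packed into 5 rows.
offsets = [0, 2, 2, 5]
print([row_to_batch_pos(offsets, r) for r in range(5)])
# [(0, 0), (0, 1), (2, 0), (2, 1), (2, 2)]
```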

Fields

  • out_accum_split_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • lse_accum_split_ptr (UnsafePointer[Scalar[accum_type], MutAnyOrigin])
  • output_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • input_row_offsets_ptr (UnsafePointer[UInt32, MutAnyOrigin])
  • attn_sink_ptr (OptionalReg[UnsafePointer[Float32, MutAnyOrigin]])
  • batch_size (Int)
  • seq_len (Int)
  • num_heads (Int)
  • head_dim (Int)
  • lse_stride_split (Int)
  • lse_stride_batch (Int)
  • lse_stride_seq (Int)
  • out_accum_stride_split (Int)
  • out_accum_stride_head (Int)
  • out_stride_row (Int)
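The field names suggest a split-KV (flash-decoding style) combine step. Each of the num_splits partitions writes a partial output to out_accum_split_ptr and a log-sum-exp (LSE) value to lse_accum_split_ptr. The combine then rescales and sums these partials into output_ptr.

The plain-Python sketch below shows that math for a single row and head, under these assumptions:

  • partial outputs are already normalized within each split;
  • LSE values use natural logarithms;
  • the optional attention sink (has_attn_sink / attn_sink_ptr) contributes one extra logit to the softmax denominator.

combine_splits is a hypothetical name, not the kernel's actual code.

```python
import math
from typing import List, Optional


def combine_splits(
    partial_out: List[List[float]],      # [num_splits][head_dim] normalized partial outputs
    partial_lse: List[float],            # [num_splits] log-sum-exp of each split's scores
    sink_logit: Optional[float] = None,  # assumed per-head sink logit (natural-log units)
) -> List[float]:
    """Merge split-KV attention partials for one (row, head) pair."""
    head_dim = len(partial_out[0])
    terms = list(partial_lse)
    if sink_logit is not None:
        # The sink adds mass to the softmax denominator but has no value vector.
        terms.append(sink_logit)
    m = max(terms)
    if m == -math.inf:
        # Every split was empty and there is no sink: nothing to attend to.
        return [0.0] * head_dim
    # Numerically stable log-sum-exp over all splits (and the sink, if present).
    total_lse = m + math.log(sum(math.exp(t - m) for t in terms))
    out = [0.0] * head_dim
    for o_s, l_s in zip(partial_out, partial_lse):
        w = math.exp(l_s - total_lse)  # this split's share of the softmax mass
        for d in range(head_dim):
            out[d] += w * o_s[d]
    return out


# Two equally weighted splits: the result is their average.
print(combine_splits([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]))
```

Without a sink, the split weights exp(l_s - L) sum to 1. With a sink they sum to less than 1, so a head can place part of its attention on "nothing".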

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = SplitParallelCombineParams[output_type, accum_type, num_splits, ragged, has_attn_sink]

heads_per_block

comptime heads_per_block = 1

num_threads

comptime num_threads = (8 * WARP_SIZE)

num_warps

comptime num_warps = 8
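num_threads is num_warps * WARP_SIZE, so the block size depends on the target GPU. The arithmetic below assumes WARP_SIZE is 32 on NVIDIA GPUs and 64 on AMD CDNA GPUs. The hypothetical_grid function is only one plausible reading of heads_per_block = 1, namely one block per query row and head. The actual launch configuration is not documented on this page.

```python
# Constants copied from this page; WARP_SIZE is target-dependent (assumed
# 32 on NVIDIA GPUs and 64 on AMD CDNA GPUs).
NUM_WARPS = 8
HEADS_PER_BLOCK = 1


def threads_per_block(warp_size: int) -> int:
    """num_threads = num_warps * WARP_SIZE."""
    return NUM_WARPS * warp_size


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def hypothetical_grid(total_rows: int, num_heads: int) -> int:
    """Hypothetical launch shape: one block per (row, group of heads_per_block heads)."""
    return total_rows * ceil_div(num_heads, HEADS_PER_BLOCK)


print(threads_per_block(32), threads_per_block(64))  # 256 512
print(hypothetical_grid(total_rows=5, num_heads=8))  # 40
```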

Methods

__init__

def __init__(out_accum_split_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], lse_accum_split_ptr: UnsafePointer[Scalar[accum_type], MutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]], batch_size: Int, seq_len: Int, num_heads: Int, head_dim: Int) -> Self
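__init__ takes only sizes, not strides, so the stride fields (lse_stride_*, out_accum_stride_*, out_stride_row) are presumably derived from those sizes. The sketch below shows one hypothetical fully contiguous layout:

  • LSE values as [num_splits, batch_size, seq_len, num_heads];
  • partial outputs as [num_splits, batch_size * seq_len, num_heads, head_dim].

It illustrates strided indexing and is not the documented layout. contiguous_strides and lse_offset are hypothetical helpers.

```python
from dataclasses import dataclass


@dataclass
class Strides:
    # Field names mirror the struct's stride fields; values below are hypothetical.
    lse_stride_split: int
    lse_stride_batch: int
    lse_stride_seq: int
    out_accum_stride_split: int
    out_accum_stride_head: int
    out_stride_row: int


def contiguous_strides(batch_size: int, seq_len: int, num_heads: int, head_dim: int) -> Strides:
    """Strides for an assumed fully contiguous, row-major layout."""
    rows = batch_size * seq_len
    return Strides(
        lse_stride_split=rows * num_heads,
        lse_stride_batch=seq_len * num_heads,
        lse_stride_seq=num_heads,
        out_accum_stride_split=rows * num_heads * head_dim,
        out_accum_stride_head=head_dim,
        out_stride_row=num_heads * head_dim,
    )


def lse_offset(s: Strides, split: int, batch: int, seq: int, head: int) -> int:
    # Heads are assumed to have unit stride, since the struct has no LSE head stride.
    return split * s.lse_stride_split + batch * s.lse_stride_batch + seq * s.lse_stride_seq + head


s = contiguous_strides(batch_size=2, seq_len=3, num_heads=4, head_dim=8)
print(lse_offset(s, split=1, batch=1, seq=2, head=3))  # 47, the last LSE entry of split 1
```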

get_type_name

static def get_type_name() -> String

Returns:

String

get_device_type_name

static def get_device_type_name() -> String

Returns:

String