
Mojo struct

CombineParams

struct CombineParams[output_type: DType, accum_type: DType, num_splits: Int, ragged: Bool = False, warps_per_head: Int = Int(2), has_attn_sink: Bool = False]

Fields

  • out_accum_split_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • lse_accum_split_ptr (UnsafePointer[Scalar[accum_type], MutAnyOrigin])
  • output_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • input_row_offsets_ptr (UnsafePointer[UInt32, MutAnyOrigin])
  • attn_sink_ptr (OptionalReg[UnsafePointer[Float32, MutAnyOrigin]])
  • batch_size (Int)
  • seq_len (Int)
  • num_heads (Int)
  • head_dim (Int)
  • lse_stride_split (Int)
  • lse_stride_batch (Int)
  • lse_stride_seq (Int)
  • out_accum_stride_split (Int)
  • out_accum_stride_head (Int)
  • out_stride_row (Int)
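The field names suggest the usual split-KV ("flash-decoding") combine step. Each of the `num_splits` partial results stores an output vector in `out_accum_split_ptr` and a log-sum-exp (LSE) in `lse_accum_split_ptr`, and the final row written to `output_ptr` is their LSE-weighted sum. The Python sketch below illustrates that math for a single (batch, seq, head) row. It is an assumption about the kernel's semantics, not its implementation. The treatment of `attn_sink_ptr` as one extra softmax logit with no value vector is also assumed, and `combine_splits` is a hypothetical helper.

```python
import math

def combine_splits(partial_outs, partial_lses, sink=None):
    """Hypothetical helper: merge per-split attention results for one row.

    Assumes each partial_outs[s] is already normalized by its own softmax
    denominator and partial_lses[s] is the natural-log sum-exp of that
    split's scores. `sink` is an assumed attention-sink logit that enlarges
    the softmax denominator but contributes no value vector.
    """
    m = max(partial_lses)
    if sink is not None:
        m = max(m, sink)
    # Global log-sum-exp, computed stably by subtracting the running max.
    total = sum(math.exp(l - m) for l in partial_lses)
    if sink is not None:
        total += math.exp(sink - m)
    global_lse = m + math.log(total)

    head_dim = len(partial_outs[0])
    out = [0.0] * head_dim
    for o, l in zip(partial_outs, partial_lses):
        w = math.exp(l - global_lse)  # this split's share of the softmax mass
        for d in range(head_dim):
            out[d] += w * o[d]
    return out, global_lse
```

For example, two splits with equal LSE values are simply averaged. `combine_splits([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])` returns `[0.5, 0.5]` with a global LSE of `log(2)`. Carrying the LSE alongside each partial output is what lets the splits be merged exactly, without revisiting the raw attention scores.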

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head, has_attn_sink]

heads_per_block

comptime heads_per_block = (Int(8) // warps_per_head)
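Because `num_threads` multiplies `heads_per_block` by `warps_per_head`, the constant 8 acts as a per-block warp budget that is shared among heads using floor division. The following Python mirror of the expression is illustrative only. The function name is hypothetical, and Python's `//` matches Mojo's `Int` floor division for the positive values shown.

```python
def heads_per_block(warps_per_head: int) -> int:
    # Mirrors `Int(8) // warps_per_head`: a budget of 8 warps per block,
    # divided among heads that each use `warps_per_head` warps.
    return 8 // warps_per_head

for w in (1, 2, 4, 8, 3):
    print(f"warps_per_head={w}: heads_per_block={heads_per_block(w)}")
```

With the default `warps_per_head = 2`, each block handles four heads. A value that does not divide 8, such as 3, rounds down to two heads per block.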

num_threads

comptime num_threads = ((CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head, has_attn_sink].heads_per_block * warps_per_head) * _resolve_warp_size())
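`num_threads` multiplies `heads_per_block` by `warps_per_head`, which gives the number of warps per block, and then by the hardware warp size. The Python sketch below reproduces that arithmetic. It assumes that `_resolve_warp_size()` returns the target's warp width, taken here as 32 for NVIDIA GPUs and 64 for AMD CDNA GPUs. The function name is hypothetical.

```python
def num_threads(warps_per_head: int, warp_size: int) -> int:
    # Mirrors heads_per_block * warps_per_head * _resolve_warp_size();
    # warp_size stands in for the (assumed) target warp width.
    heads_per_block = 8 // warps_per_head
    warps_per_block = heads_per_block * warps_per_head
    return warps_per_block * warp_size
```

With the default `warps_per_head = 2`, this gives 256 threads for a 32-wide warp and 512 for a 64-wide one. A non-divisor such as 3 yields 192 threads (6 of the 8 budgeted warps) because the floor division in `heads_per_block` drops the remainder.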

Methods

__init__

def __init__(out_accum_split_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], lse_accum_split_ptr: UnsafePointer[Scalar[accum_type], MutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]], batch_size: Int, seq_len: Int, num_heads: Int, head_dim: Int) -> Self
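`__init__` takes only the four sizes, but the struct also stores `lse_stride_*`, `out_accum_stride_*`, and `out_stride_row`. Those strides are presumably derived from the sizes (and `num_splits`) inside the constructor. The Python sketch below shows how strided element offsets of this kind are typically formed. It assumes a hypothetical dense `[num_splits, batch_size, seq_len, num_heads]` layout for the LSE buffer with heads innermost. The constructor's actual layout is not documented on this page and may differ, for example when `ragged` is `True`. `LseStrides`, `contiguous_lse_strides`, and `lse_offset` are illustrative names.

```python
from dataclasses import dataclass

@dataclass
class LseStrides:
    # Mirrors the lse_stride_split / lse_stride_batch / lse_stride_seq fields;
    # the values used below are hypothetical.
    split: int
    batch: int
    seq: int

def contiguous_lse_strides(batch_size: int, seq_len: int, num_heads: int) -> LseStrides:
    # Hypothetical dense [num_splits, batch_size, seq_len, num_heads] layout,
    # with the head index innermost (stride 1).
    return LseStrides(split=batch_size * seq_len * num_heads,
                      batch=seq_len * num_heads,
                      seq=num_heads)

def lse_offset(st: LseStrides, split: int, b: int, s: int, h: int) -> int:
    # Element offset of one partial LSE value within lse_accum_split_ptr.
    return split * st.split + b * st.batch + s * st.seq + h
```

For `batch_size=2`, `seq_len=3`, and `num_heads=4`, the strides are 24, 12, and 4. With two splits, the last element `(split=1, b=1, s=2, h=3)` lands at offset 47, the final slot of a 48-element buffer.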

get_type_name

static def get_type_name() -> String

Returns:

String

get_device_type_name

static def get_device_type_name() -> String

Returns:

String