Mojo struct

SplitParallelCombineParams

struct SplitParallelCombineParams[output_type: DType, accum_type: DType, num_splits: Int, ragged: Bool = False, has_attn_sink: Bool = False]
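Going by the field names (per-split partial outputs in `out_accum_split_ptr` and per-split log-sum-exp values in `lse_accum_split_ptr`), this struct appears to parameterize the combine step of split-KV ("flash-decoding") attention. In that step, each of the `num_splits` partial outputs is rescaled by `exp(lse_s - max_lse)` and the results are renormalized. The sketch below is a host-side Python reference model of that math for one (token, head) pair. It is not the Mojo kernel. `split_attention` and `combine_splits` are hypothetical helpers. The sink handling is also an assumption: it treats the sink, as in attention-sink models, as an extra logit that adds mass to the softmax denominator but contributes no value vector.

```python
import math
from typing import List, Optional, Tuple


def split_attention(scores: List[float], values: List[List[float]]) -> Tuple[List[float], float]:
    """Attention over one KV split: returns (normalized output, log-sum-exp)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    lse = m + math.log(total)
    head_dim = len(values[0])
    out = [sum(e * v[d] for e, v in zip(exps, values)) / total for d in range(head_dim)]
    return out, lse


def combine_splits(partial_out: List[List[float]], lse: List[float],
                   sink: Optional[float] = None) -> List[float]:
    """Merge per-split outputs using their LSE values (hypothetical reference)."""
    # Stabilize the exponentials with the largest LSE (and the sink logit, if any).
    m = max(lse)
    if sink is not None:
        m = max(m, sink)
    weights = [math.exp(l - m) for l in lse]
    denom = sum(weights)
    if sink is not None:
        # Assumption: the sink only enlarges the softmax denominator.
        denom += math.exp(sink - m)
    head_dim = len(partial_out[0])
    return [sum(w * o[d] for w, o in zip(weights, partial_out)) / denom
            for d in range(head_dim)]
```

Splitting four scores into two halves and combining the halves should reproduce single-pass attention up to floating-point rounding. This equivalence is why the kernel only needs each split's normalized output plus one LSE scalar per (split, token, head).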

Fields

  • out_accum_split_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • lse_accum_split_ptr (UnsafePointer[Scalar[accum_type], MutAnyOrigin])
  • output_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • input_row_offsets_ptr (UnsafePointer[UInt32, MutAnyOrigin])
  • attn_sink_ptr (OptionalReg[UnsafePointer[Float32, MutAnyOrigin]])
  • batch_size (Int)
  • seq_len (Int)
  • num_heads (Int)
  • head_dim (Int)
  • lse_stride_split (Int)
  • lse_stride_batch (Int)
  • lse_stride_seq (Int)
  • out_accum_stride_split (Int)
  • out_accum_stride_head (Int)
  • out_stride_row (Int)
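When `ragged` is true, `input_row_offsets_ptr` presumably holds prefix-sum row offsets, so sequences of different lengths are packed back to back. The Python model below assumes the common convention of `batch_size + 1` entries, where entry `b` is the first packed row of sequence `b`, and contrasts it with a padded layout in which every sequence occupies `seq_len` rows. All three helpers are hypothetical and not part of the Mojo API.

```python
from bisect import bisect_right
from typing import List


def ragged_row(offsets: List[int], batch: int, pos: int) -> int:
    """Packed row index of token `pos` in sequence `batch` (ragged layout, assumed)."""
    length = offsets[batch + 1] - offsets[batch]
    if not 0 <= pos < length:
        raise IndexError("position past the end of this sequence")
    return offsets[batch] + pos


def padded_row(seq_len: int, batch: int, pos: int) -> int:
    """Row index when every sequence is padded to seq_len rows."""
    return batch * seq_len + pos


def batch_of_row(offsets: List[int], row: int) -> int:
    """Inverse mapping: which sequence owns a packed ragged row."""
    return bisect_right(offsets, row) - 1
```

For example, with offsets `[0, 3, 4, 9]` (sequence lengths 3, 1, and 5), token 1 of sequence 2 lands on packed row 5, while the padded layout with `seq_len = 5` would place it on row 11.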

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = SplitParallelCombineParams[output_type, accum_type, num_splits, ragged, has_attn_sink]

heads_per_block

comptime heads_per_block = 1

num_threads

comptime num_threads = (8 * WARP_SIZE)

num_warps

comptime num_warps = 8
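Because `num_threads` is `num_warps * WARP_SIZE`, the block size depends on the target GPU. `WARP_SIZE` is typically 32 on NVIDIA GPUs and 64 on AMD CDNA GPUs, which gives 256 or 512 threads per block. The Python arithmetic below shows this. The grid-size helper is a hypothetical illustration only: it assumes, from `heads_per_block = 1`, that one block handles one head of one token row. The actual launch configuration may differ.

```python
NUM_WARPS = 8        # mirrors comptime num_warps
HEADS_PER_BLOCK = 1  # mirrors comptime heads_per_block


def num_threads(warp_size: int) -> int:
    """Threads per block: num_warps * WARP_SIZE."""
    return NUM_WARPS * warp_size


def blocks_needed(total_rows: int, num_heads: int) -> int:
    """Hypothetical grid size if each block covers HEADS_PER_BLOCK heads of one row."""
    head_groups = (num_heads + HEADS_PER_BLOCK - 1) // HEADS_PER_BLOCK
    return total_rows * head_groups
```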

Methods

__init__

__init__(out_accum_split_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], lse_accum_split_ptr: UnsafePointer[Scalar[accum_type], MutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]], batch_size: Int, seq_len: Int, num_heads: Int, head_dim: Int) -> Self
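`__init__` takes shapes (`batch_size`, `seq_len`, `num_heads`, `head_dim`) but no strides, so the stride fields are presumably derived inside the constructor. The Python dataclass below is a hypothetical model of such a derivation. It assumes densely packed row-major buffers with the layouts listed in its docstring. The real Mojo constructor may use a different layout. `CombineStrides`, `lse_offset`, and `out_accum_offset` are illustrative names, not part of the API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CombineStrides:
    """Hypothetical host-side mirror of the stride fields.

    Assumed dense row-major layouts:
      lse:       [num_splits, batch_size, seq_len, num_heads]
      out_accum: [num_splits, batch_size * seq_len, num_heads, head_dim]
      output:    [batch_size * seq_len, num_heads, head_dim]
    """
    num_splits: int  # a comptime parameter in the Mojo struct
    batch_size: int
    seq_len: int
    num_heads: int
    head_dim: int

    @property
    def lse_stride_seq(self) -> int:
        return self.num_heads

    @property
    def lse_stride_batch(self) -> int:
        return self.seq_len * self.num_heads

    @property
    def lse_stride_split(self) -> int:
        return self.batch_size * self.lse_stride_batch

    @property
    def out_accum_stride_head(self) -> int:
        return self.head_dim

    @property
    def out_stride_row(self) -> int:
        return self.num_heads * self.head_dim

    @property
    def out_accum_stride_split(self) -> int:
        return self.batch_size * self.seq_len * self.out_stride_row

    def lse_offset(self, split: int, batch: int, seq: int, head: int) -> int:
        """Flat element offset of one LSE value (heads assumed contiguous)."""
        return (split * self.lse_stride_split + batch * self.lse_stride_batch
                + seq * self.lse_stride_seq + head)

    def out_accum_offset(self, split: int, row: int, head: int, d: int) -> int:
        """Flat element offset into the per-split partial outputs."""
        return (split * self.out_accum_stride_split + row * self.out_stride_row
                + head * self.out_accum_stride_head + d)
```

Under these assumptions, the last LSE element of a buffer with 4 splits, batch 2, sequence length 3, and 8 heads sits at offset `4 * 2 * 3 * 8 - 1 = 191`. This is a quick sanity check that the strides tile the buffer without gaps.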

get_type_name

static get_type_name() -> String

Returns:

String

get_device_type_name

static get_device_type_name() -> String

Returns:

String