
Mojo struct

CombineParams

struct CombineParams[output_type: DType, accum_type: DType, num_splits: Int, ragged: Bool = False, warps_per_head: Int = 2, has_attn_sink: Bool = False]

Fields

  • out_accum_split_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • lse_accum_split_ptr (UnsafePointer[Scalar[accum_type], MutAnyOrigin])
  • output_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • input_row_offsets_ptr (UnsafePointer[UInt32, MutAnyOrigin])
  • attn_sink_ptr (OptionalReg[UnsafePointer[Float32, MutAnyOrigin]])
  • batch_size (Int)
  • seq_len (Int)
  • num_heads (Int)
  • head_dim (Int)
  • lse_stride_split (Int)
  • lse_stride_batch (Int)
  • lse_stride_seq (Int)
  • out_accum_stride_split (Int)
  • out_accum_stride_head (Int)
  • out_stride_row (Int)
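The field names suggest a split-KV (flash-decoding style) combine step. Each of the `num_splits` partial outputs in `out_accum_split_ptr` comes with a log-sum-exp value in `lse_accum_split_ptr`, and the partials are merged into `output_ptr`. The Python sketch below shows the conventional merge for one query row and one head, assuming each partial output is already normalized by its own split's softmax denominator; it is not the kernel's code. The sink handling is also an assumption: the optional value behind `attn_sink_ptr` is treated as a per-head logit that enlarges the softmax denominator without contributing a value vector. `attend_chunk` and `combine_splits` are hypothetical helpers.

```python
import math


def attend_chunk(scores, values):
    """Softmax attention over one KV chunk for a single query and head.

    Returns the chunk's output (normalized by the chunk's own denominator)
    and the chunk's log-sum-exp of the scores.
    """
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    head_dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / denom
           for d in range(head_dim)]
    return out, m + math.log(denom)


def combine_splits(out_splits, lse_splits, sink_logit=None):
    """Merge per-split partial outputs using their log-sum-exp values.

    sink_logit (assumption): an optional per-head logit that enlarges the
    softmax denominator but has no value vector of its own.
    """
    terms = list(lse_splits)
    if sink_logit is not None:
        terms.append(sink_logit)
    m = max(terms)
    head_dim = len(out_splits[0])
    if m == -math.inf:
        # All splits are empty and there is no finite sink: output is zero.
        return [0.0] * head_dim
    # Global log-sum-exp, computed stably by factoring out the max.
    lse_total = m + math.log(sum(math.exp(t - m) for t in terms))
    out = [0.0] * head_dim
    for o, lse in zip(out_splits, lse_splits):
        w = math.exp(lse - lse_total)  # this split's share of the softmax mass
        for d in range(head_dim):
            out[d] += w * o[d]
    return out


# Usage: split four keys into two chunks of two, then merge.
scores = [0.3, 1.2, -0.7, 2.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
full, _ = attend_chunk(scores, values)
o1, l1 = attend_chunk(scores[:2], values[:2])
o2, l2 = attend_chunk(scores[2:], values[2:])
merged = combine_splits([o1, o2], [l1, l2])
```

Merging the two chunks reproduces single-pass attention up to floating-point rounding. Weighting each split by `exp(lse_s - lse_total)` rather than by raw exponentials keeps the merge numerically stable for large scores.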

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

device_type

comptime device_type = CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head, has_attn_sink]

heads_per_block

comptime heads_per_block = (8 // warps_per_head)

num_threads

comptime num_threads = ((CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head, has_attn_sink].heads_per_block * warps_per_head) * WARP_SIZE)
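The two comptime members fix the launch shape: at most 8 warps per block, with `warps_per_head` warps cooperating on each head. The Python sketch below mirrors that arithmetic; `launch_shape` is a hypothetical helper, and `warp_size` stands in for `WARP_SIZE`, which is typically 32 on NVIDIA GPUs and 64 on AMD CDNA GPUs.

```python
def launch_shape(warps_per_head=2, warp_size=32):
    """Mirror the heads_per_block and num_threads comptime members.

    warp_size stands in for WARP_SIZE (32 here by default; 64 on AMD CDNA).
    """
    heads_per_block = 8 // warps_per_head  # floor division, as in the Mojo expression
    num_threads = heads_per_block * warps_per_head * warp_size
    return heads_per_block, num_threads


for wph in (1, 2, 3, 4, 8):
    print(wph, launch_shape(wph), launch_shape(wph, warp_size=64))
```

With the default `warps_per_head = 2` and a 32-wide warp, each block handles 4 heads with 256 threads. Any divisor of 8 yields a full 8-warp block; a non-divisor such as 3 leaves warps unused (2 heads × 3 warps = 6 warps, or 192 threads).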

Methods

__init__

__init__(out_accum_split_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], lse_accum_split_ptr: UnsafePointer[Scalar[accum_type], MutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]], batch_size: Int, seq_len: Int, num_heads: Int, head_dim: Int) -> Self
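With `ragged = True`, `input_row_offsets_ptr` presumably describes variable-length sequences packed back to back. The sketch below assumes the usual prefix-sum layout for ragged inputs: `batch_size + 1` offsets, with sequence `b` occupying rows `offsets[b]` up to but not including `offsets[b + 1]`. `batch_of_row` is a hypothetical helper that maps a packed row back to its sequence by binary search; the kernel itself may index differently.

```python
from bisect import bisect_right


def batch_of_row(row_offsets, row):
    """Map a packed token row to (batch index, position within its sequence).

    row_offsets: prefix sums of sequence lengths, length batch_size + 1,
    starting at 0 (assumed layout for input_row_offsets_ptr).
    """
    if not 0 <= row < row_offsets[-1]:
        raise IndexError(f"row {row} outside [0, {row_offsets[-1]})")
    # The last offset <= row marks the sequence that owns this row.
    b = bisect_right(row_offsets, row) - 1
    return b, row - row_offsets[b]


# Three sequences of lengths 3, 0 and 4 packed into 7 rows.
offsets = [0, 3, 3, 7]
print([batch_of_row(offsets, r) for r in range(7)])
```

Empty sequences (repeated offsets, such as batch 1 here) are skipped automatically, because `bisect_right` returns the position after any run of equal offsets.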

get_type_name

static get_type_name() -> String

Returns:

String

get_device_type_name

static get_device_type_name() -> String

Returns:

String