Skip to main content

Mojo struct

CombineParams

@register_passable(trivial) struct CombineParams[output_type: DType, accum_type: DType, num_splits: Int, ragged: Bool = False, warps_per_head: Int = 2]

Fields

  • out_accum_split_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • lse_accum_split_ptr (UnsafePointer[Scalar[accum_type], MutAnyOrigin])
  • output_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
  • input_row_offsets_ptr (UnsafePointer[UInt32, MutAnyOrigin])
  • batch_size (Int)
  • seq_len (Int)
  • num_heads (Int)
  • head_dim (Int)
  • lse_stride_split (Int)
  • lse_stride_batch (Int)
  • lse_stride_seq (Int)
  • out_accum_stride_split (Int)
  • out_accum_stride_head (Int)
  • out_stride_row (Int)

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

device_type

comptime device_type = CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head]

heads_per_block

comptime heads_per_block = (8 // warps_per_head)

With the default warps_per_head = 2, this evaluates to 4.

num_threads

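comptime num_threads = ((CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head].heads_per_block * warps_per_head) * WARP_SIZE)

That is, heads_per_block * warps_per_head * WARP_SIZE, where heads_per_block = 8 // warps_per_head. With the default warps_per_head = 2, this is 8 * WARP_SIZE. Because heads_per_block uses floor division, the result is smaller than 8 * WARP_SIZE when warps_per_head does not evenly divide 8.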

Methods

__init__

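__init__(out_accum_split_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], lse_accum_split_ptr: UnsafePointer[Scalar[accum_type], MutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], batch_size: Int, seq_len: Int, num_heads: Int, head_dim: Int) -> Self

Note: the constructor takes the four pointers plus batch_size, seq_len, num_heads, and head_dim. The six stride fields (lse_stride_split, lse_stride_batch, lse_stride_seq, out_accum_stride_split, out_accum_stride_head, out_stride_row) are not parameters of this constructor.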

get_type_name

static get_type_name() -> String

Returns:

String

get_device_type_name

static get_device_type_name() -> String

Returns:

String

Was this page helpful?