Mojo struct
CombineParams
@register_passable(trivial)
struct CombineParams[output_type: DType, accum_type: DType, num_splits: Int, ragged: Bool = False, warps_per_head: Int = 2]
Fields
- out_accum_split_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
- lse_accum_split_ptr (UnsafePointer[Scalar[accum_type], MutAnyOrigin])
- output_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
- input_row_offsets_ptr (UnsafePointer[UInt32, MutAnyOrigin])
- batch_size (Int)
- seq_len (Int)
- num_heads (Int)
- head_dim (Int)
- lse_stride_split (Int)
- lse_stride_batch (Int)
- lse_stride_seq (Int)
- out_accum_stride_split (Int)
- out_accum_stride_head (Int)
- out_stride_row (Int)
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
device_type
comptime device_type = CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head]
heads_per_block
comptime heads_per_block = (8 // warps_per_head)
num_threads
comptime num_threads = ((CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head].heads_per_block * warps_per_head) * WARP_SIZE)
Methods
__init__
__init__(out_accum_split_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], lse_accum_split_ptr: UnsafePointer[Scalar[accum_type], MutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], batch_size: Int, seq_len: Int, num_heads: Int, head_dim: Int) -> Self
get_type_name
get_device_type_name
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!