Mojo struct
SplitParallelCombineParams
struct SplitParallelCombineParams[output_type: DType, accum_type: DType, num_splits: Int, ragged: Bool = False, has_attn_sink: Bool = False]
Fields
- out_accum_split_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
- lse_accum_split_ptr (UnsafePointer[Scalar[accum_type], MutAnyOrigin])
- output_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
- input_row_offsets_ptr (UnsafePointer[UInt32, MutAnyOrigin])
- attn_sink_ptr (OptionalReg[UnsafePointer[Float32, MutAnyOrigin]])
- batch_size (Int)
- seq_len (Int)
- num_heads (Int)
- head_dim (Int)
- lse_stride_split (Int)
- lse_stride_batch (Int)
- lse_stride_seq (Int)
- out_accum_stride_split (Int)
- out_accum_stride_head (Int)
- out_stride_row (Int)
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type = SplitParallelCombineParams[output_type, accum_type, num_splits, ragged, has_attn_sink]
heads_per_block
comptime heads_per_block = 1
num_threads
comptime num_threads = (8 * WARP_SIZE)
num_warps
comptime num_warps = 8
Methods
__init__
__init__(out_accum_split_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], lse_accum_split_ptr: UnsafePointer[Scalar[accum_type], MutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]], batch_size: Int, seq_len: Int, num_heads: Int, head_dim: Int) -> Self
get_type_name
get_device_type_name
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!