Mojo struct
SplitParallelCombineParams
struct SplitParallelCombineParams[output_type: DType, accum_type: DType, num_splits: Int, ragged: Bool = False, has_attn_sink: Bool = False]
Fields
- `out_accum_split_ptr` (`UnsafePointer[Scalar[output_type], MutAnyOrigin]`)
- `lse_accum_split_ptr` (`UnsafePointer[Scalar[accum_type], MutAnyOrigin]`)
- `output_ptr` (`UnsafePointer[Scalar[output_type], MutAnyOrigin]`)
- `input_row_offsets_ptr` (`UnsafePointer[UInt32, MutAnyOrigin]`)
- `attn_sink_ptr` (`OptionalReg[UnsafePointer[Float32, MutAnyOrigin]]`)
- `batch_size` (`Int`)
- `seq_len` (`Int`)
- `num_heads` (`Int`)
- `head_dim` (`Int`)
- `lse_stride_split` (`Int`)
- `lse_stride_batch` (`Int`)
- `lse_stride_seq` (`Int`)
- `out_accum_stride_split` (`Int`)
- `out_accum_stride_head` (`Int`)
- `out_stride_row` (`Int`)
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type = SplitParallelCombineParams[output_type, accum_type, num_splits, ragged, has_attn_sink]
heads_per_block
comptime heads_per_block = 1
num_threads
comptime num_threads = (8 * WARP_SIZE)
num_warps
comptime num_warps = 8
Methods
__init__
__init__(out_accum_split_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], lse_accum_split_ptr: UnsafePointer[Scalar[accum_type], MutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]], batch_size: Int, seq_len: Int, num_heads: Int, head_dim: Int) -> Self
get_type_name
get_device_type_name
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!