Mojo struct
CombineParams
struct CombineParams[output_type: DType, accum_type: DType, num_splits: Int, ragged: Bool = False, warps_per_head: Int = 2, has_attn_sink: Bool = False]
Fields
- out_accum_split_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
- lse_accum_split_ptr (UnsafePointer[Scalar[accum_type], MutAnyOrigin])
- output_ptr (UnsafePointer[Scalar[output_type], MutAnyOrigin])
- input_row_offsets_ptr (UnsafePointer[UInt32, MutAnyOrigin])
- attn_sink_ptr (OptionalReg[UnsafePointer[Float32, MutAnyOrigin]])
- batch_size (Int)
- seq_len (Int)
- num_heads (Int)
- head_dim (Int)
- lse_stride_split (Int)
- lse_stride_batch (Int)
- lse_stride_seq (Int)
- out_accum_stride_split (Int)
- out_accum_stride_head (Int)
- out_stride_row (Int)
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
device_type
comptime device_type = CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head, has_attn_sink]
heads_per_block
comptime heads_per_block = (8 // warps_per_head)
num_threads
comptime num_threads = ((CombineParams[output_type, accum_type, num_splits, ragged, warps_per_head, has_attn_sink].heads_per_block * warps_per_head) * WARP_SIZE)
Methods
__init__
__init__(out_accum_split_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], lse_accum_split_ptr: UnsafePointer[Scalar[accum_type], MutAnyOrigin], output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], input_row_offsets_ptr: UnsafePointer[UInt32, MutAnyOrigin], attn_sink_ptr: OptionalReg[UnsafePointer[Float32, MutAnyOrigin]], batch_size: Int, seq_len: Int, num_heads: Int, head_dim: Int) -> Self
get_type_name
get_device_type_name
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!