Mojo struct
MHAConfig
@register_passable(trivial)
struct MHAConfig[dtype: DType]
Fields
- num_heads (`UInt`)
- depth (`UInt`)
- padded_depth (`UInt`)
- num_queries_per_block (`UInt`)
- num_keys_per_block (`UInt`)
- BK (`UInt`)
- WM (`UInt`)
- WN (`UInt`)
- num_pipeline_stages (`UInt`)
- k_group_size (`UInt`)
- algorithm (`FlashAttentionAlgorithm`)
- swizzle_mode (`TensorMapSwizzle`)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility,
Writable
Aliases
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Methods
__init__
__init__(num_heads: UInt, depth: UInt, num_queries_per_block: OptionalReg[UInt] = None, num_keys_per_block: OptionalReg[UInt] = None, BK: OptionalReg[UInt] = None, WM: OptionalReg[UInt] = None, WN: OptionalReg[UInt] = None, num_pipeline_stages: UInt = 4, k_group_size: UInt = 1, algorithm: FlashAttentionAlgorithm = FlashAttentionAlgorithm(-1), swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B) -> Self
block_m
block_n
block_k
warp_m
warp_n
num_warps_m
num_warps_n
num_consumer_threads
num_producer_threads
num_threads
swizzle_granularity
q_smem_size
kv_smem_size
k_smem_size
v_smem_size
p_smem_size
warp_scratch_smem_size
shared_mem_bytes
__str__
write_to
write_to(self, mut writer: T)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!