Mojo struct
MHAConfig
@register_passable("trivial")
struct MHAConfig[dtype: DType]
Fields
- num_heads (UInt)
- depth (UInt)
- padded_depth (UInt)
- num_queries_per_block (UInt)
- num_keys_per_block (UInt)
- BK (UInt)
- WM (UInt)
- WN (UInt)
- num_pipeline_stages (UInt)
- k_group_size (UInt)
- algorithm (FlashAttentionAlgorithm)
- swizzle_mode (TensorMapSwizzle)
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, Writable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
Methods
__init__
__init__(num_heads: Scalar[DType.uindex], depth: Scalar[DType.uindex], num_queries_per_block: OptionalReg[Scalar[DType.uindex]] = None, num_keys_per_block: OptionalReg[Scalar[DType.uindex]] = None, BK: OptionalReg[Scalar[DType.uindex]] = None, WM: OptionalReg[Scalar[DType.uindex]] = None, WN: OptionalReg[Scalar[DType.uindex]] = None, num_pipeline_stages: Scalar[DType.uindex] = 4, k_group_size: Scalar[DType.uindex] = 1, algorithm: FlashAttentionAlgorithm = FlashAttentionAlgorithm(-1), swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B) -> Self
block_m
block_m(self) -> UInt
Returns:
UInt
block_n
block_n(self) -> UInt
Returns:
UInt
block_k
block_k(self) -> UInt
Returns:
UInt
warp_m
warp_m(self) -> UInt
Returns:
UInt
warp_n
warp_n(self) -> UInt
Returns:
UInt
num_warps_m
num_warps_m(self) -> UInt
Returns:
UInt
num_warps_n
num_warps_n(self) -> UInt
Returns:
UInt
num_consumer_threads
num_consumer_threads(self) -> UInt
Returns:
UInt
num_producer_threads
num_producer_threads[producer_consumer_kernel: Bool = False](self) -> UInt
Returns:
UInt
num_threads
num_threads[producer_consumer_kernel: Bool = False](self) -> UInt
Returns:
UInt
swizzle_granularity
swizzle_granularity(self) -> UInt
Returns:
UInt
q_smem_size
q_smem_size(self, fa3: Bool = False, persistent: Bool = False) -> UInt
Returns:
UInt
kv_smem_size
kv_smem_size(self, fa3: Bool = False) -> UInt
Returns:
UInt
k_smem_size
k_smem_size(self, fa3: Bool = False) -> UInt
Returns:
UInt
v_smem_size
v_smem_size(self, fa3: Bool = False) -> UInt
Returns:
UInt
p_smem_size
p_smem_size(self) -> UInt
Returns:
UInt
warp_scratch_smem_size
warp_scratch_smem_size(self) -> UInt
Returns:
UInt
shared_mem_bytes
shared_mem_bytes[shared_kv: Bool = False, sm_90: Bool = False](self) -> UInt
Returns:
UInt
__str__
write_to
write_to(self, mut writer: T)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!