
Mojo struct

MHAConfig

struct MHAConfig[dtype: DType]

Fields

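  • num_heads (Int): corresponds to the `num_heads` argument of `__init__`.
  • depth (Int): corresponds to the `depth` argument of `__init__`.
  • padded_depth (Int): not a constructor argument. It is presumably derived from `depth`; the padding rule is not documented on this page.
  • num_queries_per_block (Int): corresponds to the optional `num_queries_per_block` argument of `__init__`. The value used when that argument is `None` is not documented on this page.
  • num_keys_per_block (Int): corresponds to the optional `num_keys_per_block` argument of `__init__`. The value used when that argument is `None` is not documented on this page.
  • BK (Int): corresponds to the optional `BK` argument of `__init__`. The value used when that argument is `None` is not documented on this page.
  • WM (Int): corresponds to the optional `WM` argument of `__init__`. The value used when that argument is `None` is not documented on this page.
  • WN (Int): corresponds to the optional `WN` argument of `__init__`. The value used when that argument is `None` is not documented on this page.
  • num_pipeline_stages (Int): corresponds to the `num_pipeline_stages` argument of `__init__` (default `4`).
  • k_group_size (Int): corresponds to the `k_group_size` argument of `__init__` (default `1`).
  • algorithm (FlashAttentionAlgorithm): corresponds to the `algorithm` argument of `__init__` (default `FlashAttentionAlgorithm(-1)`).
  • swizzle_mode (TensorMapSwizzle): corresponds to the `swizzle_mode` argument of `__init__` (default `TensorMapSwizzle.SWIZZLE_128B`).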

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable, Writable

Methods

__init__

__init__(
    num_heads: Int,
    depth: Int,
    num_queries_per_block: Optional[Int] = None,
    num_keys_per_block: Optional[Int] = None,
    BK: Optional[Int] = None,
    WM: Optional[Int] = None,
    WN: Optional[Int] = None,
    num_pipeline_stages: Int = 4,
    k_group_size: Int = 1,
    algorithm: FlashAttentionAlgorithm = FlashAttentionAlgorithm(-1),
    swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B,
) -> Self

block_m

block_m(self) -> Int

Returns:

Int

block_n

block_n(self) -> Int

Returns:

Int

block_k

block_k(self) -> Int

Returns:

Int

warp_m

warp_m(self) -> Int

Returns:

Int

warp_n

warp_n(self) -> Int

Returns:

Int

num_warps_m

num_warps_m(self) -> Int

Returns:

Int

num_warps_n

num_warps_n(self) -> Int

Returns:

Int

num_consumer_threads

num_consumer_threads(self) -> Int

Returns:

Int

num_producer_threads

num_producer_threads[producer_consumer_kernel: Bool = False](self) -> Int

Returns:

Int

num_threads

num_threads[producer_consumer_kernel: Bool = False](self) -> Int

Returns:

Int

swizzle_granularity

swizzle_granularity(self) -> Int

Returns:

Int

q_smem_size

q_smem_size(self, fa3: Bool = False, persistent: Bool = False) -> Int

Returns:

Int

kv_smem_size

kv_smem_size(self, fa3: Bool = False) -> Int

Returns:

Int

k_smem_size

k_smem_size(self, fa3: Bool = False) -> Int

Returns:

Int

v_smem_size

v_smem_size(self, fa3: Bool = False) -> Int

Returns:

Int

p_smem_size

p_smem_size(self) -> Int

Returns:

Int

warp_scratch_smem_size

warp_scratch_smem_size(self) -> Int

Returns:

Int

shared_mem_bytes

shared_mem_bytes[shared_kv: Bool = False, sm_90: Bool = False](self) -> Int

Returns:

Int

write_to

write_to(self, mut writer: T)
