Skip to main content

Mojo struct

MHAConfig

struct MHAConfig[dtype: DType]

Fields

  • num_heads (Int)
  • depth (Int)
  • padded_depth (Int)
  • num_queries_per_block (Int)
  • num_keys_per_block (Int)
  • BK (Int)
  • WM (Int)
  • WN (Int)
  • num_pipeline_stages (Int)
  • k_group_size (Int)
  • algorithm (FlashAttentionAlgorithm)
  • swizzle_mode (TensorMapSwizzle)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable, Writable

Methods

__init__

__init__(num_heads: Int, depth: Int, num_queries_per_block: Optional[Int] = None, num_keys_per_block: Optional[Int] = None, BK: Optional[Int] = None, WM: Optional[Int] = None, WN: Optional[Int] = None, num_pipeline_stages: Int = 4, k_group_size: Int = 1, algorithm: FlashAttentionAlgorithm = FlashAttentionAlgorithm(-1), swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B) -> Self

block_m

block_m(self) -> Int

Returns:

Int

block_n

block_n(self) -> Int

Returns:

Int

block_k

block_k(self) -> Int

Returns:

Int

warp_m

warp_m(self) -> Int

Returns:

Int

warp_n

warp_n(self) -> Int

Returns:

Int

num_warps_m

num_warps_m(self) -> Int

Returns:

Int

num_warps_n

num_warps_n(self) -> Int

Returns:

Int

num_consumer_threads

num_consumer_threads(self) -> Int

Returns:

Int

num_producer_threads

num_producer_threads[producer_consumer_kernel: Bool = False](self) -> Int

Returns:

Int

num_threads

num_threads[producer_consumer_kernel: Bool = False](self) -> Int

Returns:

Int

swizzle_granularity

swizzle_granularity(self) -> Int

Returns:

Int

q_smem_size

q_smem_size(self, fa3: Bool = False, persistent: Bool = False) -> Int

Returns:

Int

kv_smem_size

kv_smem_size(self, fa3: Bool = False) -> Int

Returns:

Int

k_smem_size

k_smem_size(self, fa3: Bool = False) -> Int

Returns:

Int

v_smem_size

v_smem_size(self, fa3: Bool = False) -> Int

Returns:

Int

p_smem_size

p_smem_size(self) -> Int

Returns:

Int

warp_scratch_smem_size

warp_scratch_smem_size(self) -> Int

Returns:

Int

shared_mem_bytes

shared_mem_bytes[shared_kv: Bool = False, sm_90: Bool = False](self) -> Int

Returns:

Int

write_to

write_to(self, mut writer: T)