IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

MHAConfig

struct MHAConfig[dtype: DType]

Fields

  • num_heads (Int):
  • depth (Int):
  • padded_depth (Int):
  • num_queries_per_block (Int):
  • num_keys_per_block (Int):
  • BK (Int):
  • WM (Int):
  • WN (Int):
  • num_pipeline_stages (Int):
  • k_group_size (Int):
  • algorithm (FlashAttentionAlgorithm):
  • swizzle_mode (TensorMapSwizzle):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable, Writable

Methods

__init__

def __init__(num_heads: Int, depth: Int, num_queries_per_block: Optional[Int] = None, num_keys_per_block: Optional[Int] = None, BK: Optional[Int] = None, WM: Optional[Int] = None, WN: Optional[Int] = None, num_pipeline_stages: Int = Int(4), k_group_size: Int = Int(1), algorithm: FlashAttentionAlgorithm = FlashAttentionAlgorithm(Int(-1)), swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B) -> Self

block_m

def block_m(self) -> Int

Returns:

Int

block_n

def block_n(self) -> Int

Returns:

Int

block_k

def block_k(self) -> Int

Returns:

Int

warp_m

def warp_m(self) -> Int

Returns:

Int

warp_n

def warp_n(self) -> Int

Returns:

Int

num_warps_m

def num_warps_m(self) -> Int

Returns:

Int

num_warps_n

def num_warps_n(self) -> Int

Returns:

Int

num_consumer_threads

def num_consumer_threads(self) -> Int

Returns:

Int

num_producer_threads

def num_producer_threads[producer_consumer_kernel: Bool = False](self) -> Int

Returns:

Int

num_threads

def num_threads[producer_consumer_kernel: Bool = False](self) -> Int

Returns:

Int

swizzle_granularity

def swizzle_granularity(self) -> Int

Returns:

Int

q_smem_size

def q_smem_size(self, fa3: Bool = False, persistent: Bool = False) -> Int

Returns:

Int

kv_smem_size

def kv_smem_size(self, fa3: Bool = False) -> Int

Returns:

Int

k_smem_size

def k_smem_size(self, fa3: Bool = False) -> Int

Returns:

Int

v_smem_size

def v_smem_size(self, fa3: Bool = False) -> Int

Returns:

Int

p_smem_size

def p_smem_size(self) -> Int

Returns:

Int

warp_scratch_smem_size

def warp_scratch_smem_size(self) -> Int

Returns:

Int

shared_mem_bytes

def shared_mem_bytes[shared_kv: Bool = False, sm_90: Bool = False](self) -> Int

Returns:

Int

write_to

def write_to(self, mut writer: T)