Skip to main content

Mojo struct

MHAAttentionConfig

struct MHAAttentionConfig[token_gen: Bool, config: MHAConfig, group: Int]

Implemented traits

AnyType, AttentionConfig, Copyable, ImplicitlyCopyable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

depth_padded

alias depth_padded = not (_cdna_4_or_newer() and env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() and not token_gen)

That is, `depth_padded` is `True` exactly when `USE_EXPERIMENTAL_CDNA4_MHA_KERNEL` is `False`.

double_buffer

alias double_buffer = _cdna_4_or_newer() and env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() and not token_gen

That is, `double_buffer` has the same value as `USE_EXPERIMENTAL_CDNA4_MHA_KERNEL`.

full_kv

alias full_kv = _cdna_4_or_newer() and env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() and not token_gen

That is, `full_kv` has the same value as `USE_EXPERIMENTAL_CDNA4_MHA_KERNEL`.

shared_kv

alias shared_kv = not (_cdna_4_or_newer() and env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() and not token_gen)

That is, `shared_kv` is `True` exactly when `USE_EXPERIMENTAL_CDNA4_MHA_KERNEL` is `False`.

USE_EXPERIMENTAL_CDNA4_MHA_KERNEL

alias USE_EXPERIMENTAL_CDNA4_MHA_KERNEL = _cdna_4_or_newer() and env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() and not token_gen

This is `True` only when the target is CDNA4 or newer, the `USE_EXPERIMENTAL_CDNA4_MHA_KERNEL` compile-time environment flag is set (default `False`), and `token_gen` is `False`.

Methods

q_head_idx

static q_head_idx() -> UInt

Returns:

UInt

q_tile_idx

static q_tile_idx() -> UInt

Returns:

UInt

kv_head_idx

static kv_head_idx() -> UInt

Returns:

UInt

get_mma_shape

static get_mma_shape() -> IndexList[3]

Returns:

IndexList[3]

get_q_offset

static get_q_offset[q_depth: UInt]() -> UInt32

Returns:

UInt32

get_output_offset

static get_output_offset[output_depth: UInt]() -> UInt32

Returns:

UInt32

Was this page helpful?