Mojo struct
MHAAttentionConfig
struct MHAAttentionConfig[token_gen: Bool, config: MHAConfig[config.dtype], group: Int]
Implemented traits
AnyType,
AttentionConfig,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
depth_padded
comptime depth_padded = False if MHAAttentionConfig[token_gen, config, group].use_gfx950_mha_kernel else True
double_buffer
comptime double_buffer = True if MHAAttentionConfig[token_gen, config, group].use_gfx950_mha_kernel else False
full_kv
comptime full_kv = True if MHAAttentionConfig[token_gen, config, group].use_gfx950_mha_kernel else False
shared_kv
comptime shared_kv = False if MHAAttentionConfig[token_gen, config, group].use_gfx950_mha_kernel else True
use_gfx950_mha_kernel
comptime use_gfx950_mha_kernel = (config == SIMD(64)) if (config == SIMD(64)) else (((config == SIMD(128)) or (config == SIMD(256)) or (config == SIMD(512))) if not token_gen else not token_gen)
Methods
q_head_idx
q_tile_idx
kv_head_idx
get_mma_shape
get_q_offset
get_output_offset
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!