
Mojo struct

FA4Config

@register_passable(trivial) struct FA4Config

Fields

  • MMA_M (Int)
  • BM (Int)
  • BN (Int)
  • BK0 (Int)
  • BK1 (Int)
  • depth (Int)
  • padded_depth (Int)
  • group (Int)
  • num_q_heads (Int)
  • num_kv_heads (Int)
  • TMEM_S1 (Int)
  • TMEM_O0 (Int)
  • TMEM_O1 (Int)
  • TMEM_P0 (Int)
  • TMEM_P1 (Int)
  • TMEM_C0 (Int)
  • TMEM_C1 (Int)
  • tmem_used (Int)
  • num_kv_stages (Int)
  • num_mma_stages (Int)
  • smem_used (Int)
  • dtype_size (Int)
  • split_m (Bool)
  • swizzle_mode (TensorMapSwizzle)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

mbar_size

comptime mbar_size = size_of[DType.int64]()

MMA_K

comptime MMA_K = 16

num_correction_cols

comptime num_correction_cols = 1

num_threads

comptime num_threads = 512

sm100_smem_carveout

comptime sm100_smem_carveout = (B200 - 1024)

sm100_tmem_cols

comptime sm100_tmem_cols = 512

TMEM_S0

comptime TMEM_S0 = 0

Methods

__init__

__init__(*, num_q_heads: Int, group: Int, depth: Int, dtype_size: Int, swizzle_mode: TensorMapSwizzle, page_size: Int) -> Self

num_qo

num_qo(self) -> Int

Returns:

Int

supported

supported(self) -> Bool

Returns:

Bool

use_tmem_for_correction

use_tmem_for_correction(self) -> Bool

Returns:

Bool

correction_smem_elements

correction_smem_elements(self) -> Int

Returns:

Int

num_active_warps_per_group

num_active_warps_per_group(self) -> Int

Returns:

Int

num_active_threads_per_group

num_active_threads_per_group(self) -> Int

Returns:

Int
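The following is a minimal usage sketch. It uses only the keyword-only `__init__` signature, the fields, the comptime members, and the methods listed on this page. Several details do not come from this page and are assumptions: the two import paths, the example problem shape, and the meanings given to `group` (query heads per KV head), `dtype_size` (bytes per element), and `page_size` (KV-cache page size). Check all of these against your MAX/Mojo release before relying on them.

```mojo
# ASSUMPTION: these module paths are not given on this page and may differ
# between MAX/Mojo releases.
from nn.mha_sm100_2q import FA4Config
from gpu.host.nvidia.tma import TensorMapSwizzle


def main():
    # Hypothetical problem shape (assumed parameter meanings):
    #   num_q_heads=32 query heads,
    #   group=4 query heads per KV head (assumption),
    #   depth=128 head dimension,
    #   dtype_size=2 bytes per element, e.g. bf16 (assumption),
    #   page_size=128 tokens per KV-cache page (assumption).
    # The `*` in the signature makes every argument keyword-only.
    var config = FA4Config(
        num_q_heads=32,
        group=4,
        depth=128,
        dtype_size=2,
        swizzle_mode=TensorMapSwizzle.SWIZZLE_128B,
        page_size=128,
    )

    # Only use the configuration if supported() reports it as valid.
    if not config.supported():
        print("FA4Config: configuration not supported")
        return

    # Inspect a few derived values; comptime members such as
    # sm100_tmem_cols are read through the type name.
    print("num_qo:", config.num_qo())
    print("TMEM columns used:", config.tmem_used, "of", FA4Config.sm100_tmem_cols)
    print("shared memory used:", config.smem_used)
    print("TMEM for correction:", config.use_tmem_for_correction())
```

Because `FA4Config` is `@register_passable(trivial)`, the `config` value is cheap to copy. You can pass it by value, for example as a compile-time parameter when specializing a kernel.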
