IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MLASparseSharedMemoryFP8

struct MLASparseSharedMemoryFP8[config: MLASparseConfig[config.qkv_dtype, config.b_topk_, config.num_mbars_, config.q_smem_depth_, config.q_tmem_depth_], scale_block_size: Int]

Fields

  • base (MLASparseSharedMemory[config])
  • k_scales (InlineArray[Float32, MLASparseSharedMemoryFP8[config, scale_block_size].K_SCALES_SIZE])
  • v_scales (InlineArray[Float32, MLASparseSharedMemoryFP8[config, scale_block_size].V_SCALES_SIZE])
  • k_fp8_tma_done (InlineArray[SharedMemBarrier, 2])
  • v_fp8_tma_done (InlineArray[SharedMemBarrier, 2])

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

K_scales_per_token

comptime K_scales_per_token = ceildiv(config.qk_depth, scale_block_size)
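Because the count uses ceildiv rather than floor division, a depth that is not a multiple of scale_block_size still gets a scale for its trailing partial block. The sketch below illustrates this with hypothetical dimensions (qk_depth = 576, scale_block_size = 128); these are example values, not defaults of MLASparseConfig, and the helper scales_per_token is an illustrative name, not part of the library.

```mojo
from math import ceildiv


fn scales_per_token(depth: Int, scale_block_size: Int) -> Int:
    # Hypothetical helper mirroring K_scales_per_token / V_scales_per_token:
    # one scale per block of scale_block_size elements, rounded up.
    return ceildiv(depth, scale_block_size)


def main():
    # Hypothetical example values, not read from a real MLASparseConfig.
    # 576 / 128 = 4.5, so ceildiv gives 5; floor division (576 // 128)
    # would give 4 and leave the last 64 elements without a scale.
    print(scales_per_token(576, 128))
```

With these values the program prints 5.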

K_SCALES_SIZE

comptime K_SCALES_SIZE = (MLASparseSharedMemoryFP8[config, scale_block_size].TOPK_PER_CTA * MLASparseSharedMemoryFP8[config, scale_block_size].K_scales_per_token)
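K_SCALES_SIZE combines two other comptime members: TOPK_PER_CTA (B_TOPK // 2) and K_scales_per_token. The following sketch evaluates the same arithmetic at run time for hypothetical values (B_TOPK = 128, qk_depth = 576, scale_block_size = 128); the function k_scales_size is an illustrative name, not a library API.

```mojo
from math import ceildiv


fn k_scales_size(b_topk: Int, qk_depth: Int, scale_block_size: Int) -> Int:
    # Hypothetical helper mirroring the K_SCALES_SIZE definition.
    # TOPK_PER_CTA = B_TOPK // 2
    var topk_per_cta = b_topk // 2
    # K_scales_per_token = ceildiv(qk_depth, scale_block_size)
    var k_scales_per_token = ceildiv(qk_depth, scale_block_size)
    return topk_per_cta * k_scales_per_token


def main():
    # Hypothetical configuration values: 64 tokens * 5 scales = 320.
    print(k_scales_size(128, 576, 128))
```

With these values the k_scales array would hold 320 Float32 entries.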

num_mbars

comptime num_mbars = 2

TOPK_PER_CTA

comptime TOPK_PER_CTA = (config.B_TOPK // 2)

V_scales_per_token

comptime V_scales_per_token = ceildiv(config.v_depth, scale_block_size)

V_SCALES_SIZE

comptime V_SCALES_SIZE = (config.B_TOPK * MLASparseSharedMemoryFP8[config, scale_block_size].V_scales_per_token)
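Note that, per the definitions above, V_SCALES_SIZE multiplies by the full B_TOPK, while K_SCALES_SIZE multiplies by TOPK_PER_CTA (B_TOPK // 2). The sketch below computes V_SCALES_SIZE and the corresponding shared-memory bytes (4 per Float32 entry) for hypothetical values (B_TOPK = 128, v_depth = 512, scale_block_size = 128); v_scales_size is an illustrative name, not a library API.

```mojo
from math import ceildiv


fn v_scales_size(b_topk: Int, v_depth: Int, scale_block_size: Int) -> Int:
    # Hypothetical helper mirroring the V_SCALES_SIZE definition.
    # Unlike K_SCALES_SIZE, this uses the full B_TOPK, not TOPK_PER_CTA.
    return b_topk * ceildiv(v_depth, scale_block_size)


def main():
    # Hypothetical configuration values: 128 tokens * 4 scales = 512.
    var n = v_scales_size(128, 512, 128)
    # v_scales is an InlineArray of Float32, so each entry takes 4 bytes.
    print(n, n * 4)
```

With these values the program prints 512 2048, so the v_scales array would occupy 2048 bytes of shared memory.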