
Mojo module

mha_fa3_utils

comptime values

KVTMATile

comptime KVTMATile[
    dtype: DType,
    swizzle_mode: TensorMapSwizzle,
    *,
    BN: Int,
    depth: Int,
    BK: Int = depth
] = TMATensorTile[
    dtype,
    _split_last_layout[dtype](
        IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True
    ),
    _ragged_desc_layout[dtype](
        IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode
    )
]

Parameters

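  • dtype (DType): Element type of the tile.
  • swizzle_mode (TensorMapSwizzle): Swizzle mode passed to the layout helpers that build the tile.
  • BN (Int): Size of the first dimension of the tile shape (BN, 1, BK).
  • depth (Int): Used as the default value of BK.
  • BK (Int): Size of the last dimension of the tile shape (BN, 1, BK). Defaults to depth.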
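As a worked example of the tile shape in the definition above, the calculation below assumes illustrative values dtype = DType.bfloat16 (2 bytes per element), BN = 128, and depth = 128, with BK left at its default. The last two lines reflect the CUDA tensor-map rule that, when swizzling is enabled, the inner box dimension in bytes must not exceed the swizzle span (128 bytes for 128-byte swizzle). Whether _split_last_layout splits the last dimension in exactly this way is an assumption; this page does not state it.

```latex
\begin{aligned}
\text{shape} &= (BN,\ 1,\ BK) = (128,\ 1,\ 128), \qquad BK = \text{depth} = 128 \\
\text{bytes per tile} &= BN \cdot 1 \cdot BK \cdot \operatorname{sizeof}(\text{dtype}) = 128 \cdot 128 \cdot 2 = 32768 \\
\text{bytes per row} &= BK \cdot \operatorname{sizeof}(\text{dtype}) = 128 \cdot 2 = 256 \\
\text{128-byte swizzle chunks per row} &= 256 / 128 = 2 \qquad (128 / 2 = 64 \text{ elements per chunk})
\end{aligned}
```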

QTMATile

comptime QTMATile[
    dtype: DType,
    swizzle_mode: TensorMapSwizzle,
    *,
    BM: Int,
    depth: Int,
    group: Int,
    decoding: Bool
] = TMATensorTile[
    dtype,
    _split_last_layout[dtype](
        q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding](),
        swizzle_mode,
        True
    ),
    _ragged_desc_layout[dtype](
        q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding](),
        swizzle_mode
    )
]

Parameters

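  • dtype (DType): Element type of the tile.
  • swizzle_mode (TensorMapSwizzle): Swizzle mode passed to q_smem_shape and to the layout helpers that build the tile.
  • BM (Int): Forwarded to q_smem_shape, which computes the tile shape.
  • depth (Int): Forwarded to q_smem_shape, which computes the tile shape.
  • group (Int): Forwarded to q_smem_shape, which computes the tile shape.
  • decoding (Bool): Forwarded to q_smem_shape, which computes the tile shape.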

Structs

Traits

Functions
