Mojo struct
TiledMmaOp
struct TiledMmaOp[out_type: DType, in_type: DType, shape: IndexList[Int(3)], transpose_b: Bool = False]
TileTensor-native MMA operation for AMD attention kernels.
Wraps the raw GPU MMA intrinsic and operates directly on TileTensor register tiles.
Parameters
- out_type (DType): Accumulator data type.
- in_type (DType): Input matrix element data type.
- shape (IndexList[Int(3)]): MMA instruction shape [M, N, K].
- transpose_b (Bool): Whether to transpose the B matrix.
Implemented traits
comptime members
a_frag_size
comptime a_frag_size = num_matrix_reg[shape[Int(0)], shape[Int(2)]]()
b_frag_size
comptime b_frag_size = num_matrix_reg[shape[Int(2)], shape[Int(1)]]()
c_frag_size
comptime c_frag_size = num_matrix_reg[shape[Int(0)], shape[Int(1)]]()
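The following worked example rests on two assumptions. First, num_matrix_reg[rows, cols]() returns rows × cols divided by the warp size, so each lane holds an equal share of the fragment. Second, the warp size is 64 lanes, as on AMD CDNA GPUs. The shape [M, N, K] = [32, 32, 8] is chosen only for illustration. Under these assumptions, the three fragment sizes would be:

```latex
\begin{aligned}
\text{a\_frag\_size} &= \frac{M \cdot K}{64} = \frac{32 \cdot 8}{64} = 4,\\
\text{b\_frag\_size} &= \frac{K \cdot N}{64} = \frac{8 \cdot 32}{64} = 4,\\
\text{c\_frag_size} &= \frac{M \cdot N}{64} = \frac{32 \cdot 32}{64} = 16.
\end{aligned}
```

In that case, each lane would hold 4 A elements, 4 B elements, and 16 accumulator elements per MMA fragment.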
Methods
mma
static def mma[swap_a_b: Bool = False](a: TileTensor[Storage=a.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[Storage=b.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[Storage=c.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size])
Perform MMA on TileTensor operands.
Tiles the operands down to individual MMA fragments so that the compiler can prove their shapes statically, then calls the raw gpu_mma intrinsic directly.
Parameters:
- swap_a_b (Bool): Whether to swap the A and B operands. This only controls the argument order of gpu_mma; accumulator indexing is always col-major over (M, N).
Args:
- a (TileTensor[Storage=a.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size]): A operand tile [num_m_mmas, a_frag_size].
- b (TileTensor[Storage=b.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=b.linear_idx_type, element_size=b.element_size]): B operand tile [num_n_mmas, b_frag_size].
- c (TileTensor[Storage=c.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size]): Accumulator tile [num_m_mmas * num_n_mmas, c_frag_size], modified in place.
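The accumulator layout can be written as an index formula. This reads "col-major over (M, N)" in the usual sense, with the M index varying fastest. Here m is a fragment row of a and n is a fragment row of b. Under that reading, the accumulator fragment updated by the (m, n) MMA would sit at row:

```latex
\text{row}(m, n) = m + n \cdot \text{num\_m\_mmas},
\qquad 0 \le m < \text{num\_m\_mmas},\quad 0 \le n < \text{num\_n\_mmas}
```

For example, with num_m_mmas = 2 and num_n_mmas = 2, fragments (0, 0), (1, 0), (0, 1), and (1, 1) would map to rows 0, 1, 2, and 3 of c. This mapping would be the same whether or not swap_a_b is set, since swap_a_b only changes the gpu_mma argument order.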
load_b
static def load_b[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, Storage=warp_tile.Storage, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, Storage=reg_tile.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: Int = Int(0))
Load B-matrix fragments from SMEM to registers.
Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Handles both transposed and non-transposed B layouts via comptime dispatch.
Parameters:
- swizzle (Optional[Swizzle]): Optional swizzle for SMEM bank-conflict avoidance.
Args:
- warp_tile (TileTensor[in_type, Storage=warp_tile.Storage, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size]): Source warp tile in shared memory.
- reg_tile (TileTensor[in_type, Storage=reg_tile.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size]): Destination register tile for MMA fragments.
- k_group_idx (Int): K-dimension group index within the warp tile.
load_a
static def load_a[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, Storage=warp_tile.Storage, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, Storage=reg_tile.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: Int = Int(0))
Load A-matrix fragments from SMEM to registers.
Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Always uses col_major thread distribution.
Parameters:
- swizzle (Optional[Swizzle]): Optional swizzle for SMEM access.
Args:
- warp_tile (TileTensor[in_type, Storage=warp_tile.Storage, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size]): Source warp tile in shared memory.
- reg_tile (TileTensor[in_type, Storage=reg_tile.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size]): Destination register tile for MMA fragments.
- k_group_idx (Int): K-dimension group index within the warp tile.