Mojo struct
TiledMmaOp
struct TiledMmaOp[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool = False]
TileTensor-native MMA operation for AMD attention kernels.
Wraps the raw GPU MMA intrinsic and operates directly on TileTensor register tiles.
Parameters
- out_type (DType): Accumulator data type.
- in_type (DType): Input matrix element data type.
- shape (IndexList[3]): MMA instruction shape [M, N, K].
- transpose_b (Bool): Whether to transpose the B matrix.
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
a_frag_size
comptime a_frag_size = num_matrix_reg[shape[0], shape[2]]()
b_frag_size
comptime b_frag_size = num_matrix_reg[shape[2], shape[1]]()
c_frag_size
comptime c_frag_size = num_matrix_reg[shape[0], shape[1]]()
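The three fragment sizes differ only in which pair of shape dimensions they pass to num_matrix_reg: (M, K) for A, (K, N) for B, and (M, N) for C. The plain-Python sketch below models that arithmetic. It assumes, without confirmation from this page, that num_matrix_reg divides the matrix element count evenly across a 64-lane AMD wavefront; both the formula and the WARP_SIZE constant are illustrative stand-ins, not the library's definition.

```python
# Assumption: AMD wavefront size; not stated on this page.
WARP_SIZE = 64


def num_matrix_reg(rows, cols, warp_size=WARP_SIZE):
    """Hypothetical model: elements of a rows x cols matrix held per lane."""
    return (rows * cols) // warp_size


def frag_sizes(shape):
    """Map an MMA shape [M, N, K] to per-lane fragment sizes."""
    m, n, k = shape
    return {
        "a_frag_size": num_matrix_reg(m, k),  # A is M x K
        "b_frag_size": num_matrix_reg(k, n),  # B is K x N
        "c_frag_size": num_matrix_reg(m, n),  # C is M x N
    }


print(frag_sizes((32, 32, 8)))
print(frag_sizes((16, 16, 16)))
```

Under these assumptions, a [32, 32, 8] shape gives each lane 4 elements of A, 4 of B, and 16 of C.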
Methods
mma
static mma[swap_a_b: Bool = False](a: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size])
Perform MMA on TileTensor operands.
Tiles down to individual MMA fragments so the compiler can prove static shapes, then calls the raw gpu_mma intrinsic directly.
Parameters:
- swap_a_b (Bool): Whether to swap the A and B operands. Only controls the argument order of gpu_mma; accumulator indexing is always col-major over (M, N).
Args:
- a (TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size]): A operand tile [num_m_mmas, a_frag_size].
- b (TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=b.linear_idx_type, element_size=b.element_size]): B operand tile [num_n_mmas, b_frag_size].
- c (TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size]): Accumulator tile [num_m_mmas * num_n_mmas, c_frag_size], modified in-place.
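The relationship between the operand tiles and the accumulator rows can be modeled in plain Python. This is a sketch of the indexing only, not the Mojo implementation: the helper names and the loop nesting order are assumptions, and "col-major over (M, N)" is read here as the M index varying fastest. It also illustrates that swap_a_b changes which fragment is passed first while leaving the accumulator index untouched.

```python
def accumulator_index(m_mma, n_mma, num_m_mmas):
    """Col-major index over (M, N): the M index varies fastest."""
    return n_mma * num_m_mmas + m_mma


def tiled_mma_order(num_m_mmas, num_n_mmas, swap_a_b=False):
    """List (first_operand, second_operand, c_row) for each fragment-level MMA.

    Hypothetical helper: the loop order is an assumption; only the
    accumulator indexing rule comes from the documentation.
    """
    calls = []
    for n_mma in range(num_n_mmas):
        for m_mma in range(num_m_mmas):
            a_frag = ("a", m_mma)  # row m_mma of the A register tile
            b_frag = ("b", n_mma)  # row n_mma of the B register tile
            c_row = accumulator_index(m_mma, n_mma, num_m_mmas)
            if swap_a_b:
                # Only the operand order changes; c_row is unaffected.
                calls.append((b_frag, a_frag, c_row))
            else:
                calls.append((a_frag, b_frag, c_row))
    return calls


for call in tiled_mma_order(2, 2):
    print(call)
```

With num_m_mmas = 2 and num_n_mmas = 2, fragment pair (m, n) = (1, 0) lands in accumulator row 1 and (0, 1) in row 2, regardless of swap_a_b.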
load_b
static load_b[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: Int = 0)
Load B-matrix fragments from SMEM to registers.
Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Handles both transposed and non-transposed B layouts via comptime dispatch.
Parameters:
- swizzle (Optional[Swizzle]): Optional swizzle for SMEM bank-conflict avoidance.
Args:
- warp_tile (TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size]): Source warp tile in shared memory.
- reg_tile (TileTensor[in_type, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size]): Destination register tile for MMA fragments.
- k_group_idx (Int): K-dimension group index within the warp tile.
load_a
static load_a[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: Int = 0)
Load A-matrix fragments from SMEM to registers.
Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Always uses col_major thread distribution.
Parameters:
- swizzle (Optional[Swizzle]): Optional swizzle for SMEM access.
Args:
- warp_tile (TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size]): Source warp tile in shared memory.
- reg_tile (TileTensor[in_type, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size]): Destination register tile for MMA fragments.
- k_group_idx (Int): K-dimension group index within the warp tile.