Mojo struct
TiledMmaOp
struct TiledMmaOp[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool = False]
TileTensor-native MMA operation for AMD attention kernels.
Wraps the raw GPU MMA intrinsic and operates directly on TileTensor register tiles.
Parameters
- out_type (DType): Accumulator data type.
- in_type (DType): Input matrix element data type.
- shape (IndexList): MMA instruction shape [M, N, K].
- group_size (Int): Number of MMA operations along the K dimension.
- transpose_b (Bool): Whether to transpose the B matrix.
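For concreteness, a hedged sketch of how the struct might be parameterized; the MMA shape, group size, and data types below are illustrative assumptions chosen to match a common AMD MFMA configuration, not values taken from this page:

```mojo
# Hypothetical instantiation; the exact instruction shape and group
# size depend on the target AMD GPU and the surrounding kernel.
alias MmaOp = TiledMmaOp[
    out_type = DType.float32,          # accumulator element type
    in_type = DType.bfloat16,          # A/B element type
    shape = IndexList[3](16, 16, 16),  # one MMA instruction: M=16, N=16, K=16
    group_size = 4,                    # four MMAs chained along K
    transpose_b = True,                # B is stored transposed
]
```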
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
a_frag_size
comptime a_frag_size = num_matrix_reg[shape[0], shape[2]]()
b_frag_size
comptime b_frag_size = num_matrix_reg[shape[2], shape[1]]()
c_frag_size
comptime c_frag_size = num_matrix_reg[shape[0], shape[1]]()
Methods
mma
static mma[swap_a_b: Bool = False](a: TileTensor[a.dtype, a.LayoutType, a.origin, address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b.dtype, b.LayoutType, b.origin, address_space=AddressSpace.LOCAL, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[c.dtype, c.LayoutType, c.origin, address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size])
Perform grouped MMA on TileTensor operands.
Tiles down to individual MMA fragments so the compiler can prove static shapes, then calls the raw gpu_mma intrinsic directly.
Parameters:
- swap_a_b (Bool): Whether to swap A and B operands.
Args:
- a (TileTensor): A operand tile [num_m_mmas, group_size * a_frag_size].
- b (TileTensor): B operand tile [num_n_mmas, group_size * b_frag_size].
- c (TileTensor): Accumulator tile [num_m_mmas * num_n_mmas, c_frag_size], modified in-place.
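A hedged sketch of a call site, assuming `MmaOp` is an alias of this struct and `a_reg`, `b_reg`, and `c_reg` are LOCAL-address-space TileTensors already shaped as described above (all names here are illustrative):

```mojo
# c_reg accumulates in place across repeated calls.
MmaOp.mma(a_reg, b_reg, c_reg)

# With swap_a_b, A and B are exchanged before the intrinsic is issued,
# which can help when the accumulator layout is effectively transposed.
MmaOp.mma[swap_a_b=True](a_reg, b_reg, c_reg)
```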
load_b
static load_b[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, warp_tile.LayoutType, warp_tile.origin, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, reg_tile.LayoutType, reg_tile.origin, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: UInt = 0)
Load B-matrix fragments from SMEM to registers.
Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Handles both transposed and non-transposed B layouts via comptime dispatch.
Parameters:
- swizzle (Optional): Optional swizzle for SMEM bank-conflict avoidance.
Args:
- warp_tile (TileTensor): Source warp tile in shared memory.
- reg_tile (TileTensor): Destination register tile for MMA fragments.
- k_group_idx (UInt): K-dimension group index within the warp tile.
load_a
static load_a[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, warp_tile.LayoutType, warp_tile.origin, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, reg_tile.LayoutType, reg_tile.origin, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: UInt = 0)
Load A-matrix fragments from SMEM to registers.
Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Always uses col_major thread distribution.
Parameters:
- swizzle (Optional): Optional swizzle for SMEM access.
Args:
- warp_tile (TileTensor): Source warp tile in shared memory.
- reg_tile (TileTensor): Destination register tile for MMA fragments.
- k_group_idx (UInt): K-dimension group index within the warp tile.
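Putting the pieces together, a hedged sketch of the inner loop a kernel might run: stage A and B fragments from shared memory, then issue the grouped MMA. Here `a_smem_warp_tile` and `b_smem_warp_tile` are assumed SHARED-space warp tiles, `swizzle_pattern` an assumed Swizzle, and `num_k_groups` an assumed compile-time loop bound; none of these names come from this page:

```mojo
# One K-group per iteration: load A and B fragments for that group,
# then accumulate into c_reg via the grouped MMA.
@parameter
for k_group in range(num_k_groups):
    MmaOp.load_a[swizzle=swizzle_pattern](a_smem_warp_tile, a_reg, k_group)
    MmaOp.load_b[swizzle=swizzle_pattern](b_smem_warp_tile, b_reg, k_group)
    MmaOp.mma(a_reg, b_reg, c_reg)
```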