
Mojo function

smem_mma_subtile

smem_mma_subtile[mma_rows: Int, mma_cols: Int, BN: Int, BK: Int, dtype: DType](smem_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], bk_tile: Int, k_sub: Int, mma_idx: Int) -> TileTensor[dtype, Layout[ComptimeInt[mma_rows], ComptimeInt[mma_cols], ComptimeInt[mma_cols], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED]

Creates a flat TileTensor view of one MMA-sized sub-tile inside a blocked shared-memory (SMEM) buffer.

This function is used by the non-transposed (V buffer) load_from_shared path. The V buffer's SMEM has shape (BN, depth) and is stored in a blocked layout of num_repeats BN×BK blocks (num_repeats = depth / BK). Each MMA tile is an mma_rows × mma_cols region that lies within a single block.
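To make the blocked layout concrete, the sketch below maps a logical element V[n, d] of the (BN, depth) buffer to an element offset. It is a hypothetical Python model, not the Mojo implementation: it assumes the depth / BK blocks are stored back to back and that each BN×BK block is row-major with row stride BK. The helper name `blocked_offset` is illustrative only.

```python
# Hypothetical model of the blocked V-buffer layout; not the Mojo implementation.
# Assumptions: the depth // BK blocks are stored back to back, and each
# BN x BK block is row-major (row stride BK).


def blocked_offset(n: int, d: int, BN: int, BK: int, depth: int) -> int:
    """Return the element offset of logical V[n, d] in the blocked buffer."""
    if depth % BK != 0:
        raise ValueError("depth must be a multiple of BK")
    if not (0 <= n < BN and 0 <= d < depth):
        raise IndexError("(n, d) is outside the (BN, depth) buffer")
    block = d // BK        # which BN x BK block holds column d
    col_in_block = d % BK  # column position inside that block
    return block * (BN * BK) + n * BK + col_in_block


if __name__ == "__main__":
    BN, BK, depth = 64, 64, 128  # illustrative sizes only
    print(depth // BK)                           # num_repeats: 2
    print(blocked_offset(1, 0, BN, BK, depth))   # 64 (next row, same block)
    print(blocked_offset(0, 64, BN, BK, depth))  # 4096 (start of block 1)
```

Under these assumptions, every element of the (BN, depth) buffer gets a unique offset, and moving from one block to the next jumps by BN × BK elements.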

Parameters:

  • mma_rows (Int): MMA tile height (e.g., MMA_K=16).
  • mma_cols (Int): MMA tile width (e.g., MMA_M=32).
  • BN (Int): Block height.
  • BK (Int): Block width.
  • dtype (DType): Element data type.

Args:

  • smem_ptr (UnsafePointer): Base pointer to the SMEM allocation for this buffer stage.
  • bk_tile (Int): Index of the BK-tall row group, in the range 0 to depth/BK - 1.
  • k_sub (Int): Index of the MMA_K-tall sub-row within that BK group, in the range 0 to BK/MMA_K - 1.
  • mma_idx (Int): Linear MMA tile index across the full depth dimension.
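One way to read the bk_tile and k_sub descriptions is that together they select the sub-tile's starting row: bk_tile picks a BK-tall group and k_sub picks an MMA_K-tall slice inside it. The Python sketch below models that reading and checks the documented ranges. The start-row formula, the divisibility requirements, and the helper name `subtile_row_start` are assumptions for illustration, and mma_idx is not modeled because its mapping is not described on this page.

```python
# Hypothetical reading of the bk_tile / k_sub arguments; not the Mojo code.
# Assumptions: the sub-tile's first row is bk_tile * BK + k_sub * mma_rows,
# and depth % BK == 0 and BK % mma_rows == 0. mma_idx is not modeled.


def subtile_row_start(bk_tile: int, k_sub: int, BK: int, mma_rows: int, depth: int) -> int:
    """Return the first row covered by the (bk_tile, k_sub) sub-tile."""
    if depth % BK != 0 or BK % mma_rows != 0:
        raise ValueError("expected depth % BK == 0 and BK % mma_rows == 0")
    if not 0 <= bk_tile < depth // BK:
        raise ValueError(f"bk_tile must be in 0..{depth // BK - 1}")
    if not 0 <= k_sub < BK // mma_rows:
        raise ValueError(f"k_sub must be in 0..{BK // mma_rows - 1}")
    return bk_tile * BK + k_sub * mma_rows


if __name__ == "__main__":
    BK, mma_rows, depth = 64, 16, 128  # mma_rows = MMA_K = 16, as in the example above
    starts = [
        subtile_row_start(b, k, BK, mma_rows, depth)
        for b in range(depth // BK)
        for k in range(BK // mma_rows)
    ]
    print(starts)  # [0, 16, 32, 48, 64, 80, 96, 112]
```

Under this reading, iterating bk_tile over its full range and k_sub over its full range visits every MMA_K-row slice of depth exactly once.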

Returns:

TileTensor: A row-major view of the mma_rows × mma_cols sub-tile in shared memory, with strides (mma_cols, 1).
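The return type's layout, Layout[ComptimeInt[mma_rows], ComptimeInt[mma_cols], ComptimeInt[mma_cols], ComptimeInt[1]], can be read as shape (mma_rows, mma_cols) with strides (mma_cols, 1). The Python sketch below models how such a view addresses its elements. It is an illustration under that reading, not the TileTensor API: `base` stands in for the sub-tile's starting offset, which smem_mma_subtile derives from its arguments in a way this page does not specify, and `subtile_element_offset` is a hypothetical helper.

```python
# Model of the returned view's addressing, reading the return type's layout
# as shape (mma_rows, mma_cols) with strides (mma_cols, 1). `base` stands in
# for the sub-tile's start offset, which this page does not specify.


def subtile_element_offset(base: int, r: int, c: int, mma_rows: int, mma_cols: int) -> int:
    """Return the SMEM element offset of element (r, c) of the sub-tile view."""
    if not (0 <= r < mma_rows and 0 <= c < mma_cols):
        raise IndexError("(r, c) is outside the mma_rows x mma_cols tile")
    return base + r * mma_cols + c  # row stride mma_cols, column stride 1


if __name__ == "__main__":
    mma_rows, mma_cols, base = 16, 32, 1024  # illustrative values only
    print(subtile_element_offset(base, 1, 0, mma_rows, mma_cols))    # 1056
    print(subtile_element_offset(base, 15, 31, mma_rows, mma_cols))  # 1535
```

Because the strides are (mma_cols, 1), the view addresses one contiguous run of mma_rows × mma_cols elements starting at base.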
