Mojo function
smem_mma_subtile
smem_mma_subtile[mma_rows: Int, mma_cols: Int, BN: Int, BK: Int, dtype: DType](smem_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], bk_tile: Int, k_sub: Int, mma_idx: Int) -> TileTensor[dtype, Layout[ComptimeInt[mma_rows], ComptimeInt[mma_cols], ComptimeInt[mma_cols], ComptimeInt[1]], MutAnyOrigin, address_space=AddressSpace.SHARED]
Creates a flat TileTensor view of one MMA-sized sub-tile in blocked shared memory (SMEM).
Used by the non-transposed (V buffer) load_from_shared path. The V buffer's SMEM has shape (BN, depth) and is stored in a blocked layout made of num_repeats BN × BK blocks. Each MMA tile is an mma_rows × mma_cols region that lies within a single block.
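The sketch below shows the element addressing that the returned view implies. It assumes that the four `ComptimeInt` parameters of the return `Layout` are the shape (mma_rows, mma_cols) followed by the strides (mma_cols, 1). `subtile_offset` is a hypothetical helper written for illustration and is not part of the library. It computes only an element's offset from the sub-tile's base. It does not model where that base sits inside the blocked buffer.

```mojo
# Hypothetical helper (not a library API): maps a (row, col) coordinate
# inside one MMA sub-tile to its element offset from the sub-tile's base,
# assuming a row-major layout with shape (mma_rows, mma_cols) and
# strides (mma_cols, 1).
fn subtile_offset[mma_rows: Int, mma_cols: Int](row: Int, col: Int) -> Int:
    debug_assert(0 <= row and row < mma_rows, "row out of range")
    debug_assert(0 <= col and col < mma_cols, "col out of range")
    return row * mma_cols + col


def main():
    # Example sizes taken from the parameter docs: MMA_K=16 rows, MMA_M=32 cols.
    print(subtile_offset[16, 32](0, 0))    # 0: first element
    print(subtile_offset[16, 32](1, 0))    # 32: start of the second row
    print(subtile_offset[16, 32](15, 31))  # 511: last element
```

With unit column stride and a row stride equal to the tile width, the offsets run from 0 to mma_rows * mma_cols - 1 with no gaps. Under the stated assumption, the sub-tile therefore occupies a contiguous run of SMEM.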
Parameters:
- mma_rows (Int): MMA tile height (e.g., MMA_K=16).
- mma_cols (Int): MMA tile width (e.g., MMA_M=32).
- BN (Int): Block height.
- BK (Int): Block width.
- dtype (DType): Element data type.
Args:
- smem_ptr (UnsafePointer): Base pointer to the SMEM allocation for this buffer stage.
- bk_tile (Int): Which BK-tall row group to select (0..depth/BK-1).
- k_sub (Int): Which MMA_K sub-row to select within the BK group (0..BK/MMA_K-1).
- mma_idx (Int): Linear MMA tile index across the full depth dimension.
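The documented ranges of bk_tile and k_sub define a nested iteration space: an outer index over BK groups and an inner index over MMA_K sub-rows within each group. The sketch below enumerates that space using hypothetical sizes (depth=128, BK=64, MMA_K=16). `walk_k_subtiles` is an illustrative name and does not call smem_mma_subtile. The real kernel's loop order and its derivation of mma_idx are not shown.

```mojo
# Hypothetical illustration: walks every (bk_tile, k_sub) pair allowed by
# the documented argument ranges and returns how many pairs there are.
fn walk_k_subtiles[depth: Int, BK: Int, MMA_K: Int]() -> Int:
    var count = 0
    # bk_tile ranges over 0 .. depth/BK - 1
    for bk_tile in range(depth // BK):
        # k_sub ranges over 0 .. BK/MMA_K - 1
        for k_sub in range(BK // MMA_K):
            print("bk_tile =", bk_tile, "k_sub =", k_sub)
            count += 1
    return count


def main():
    # Hypothetical sizes, for illustration only: 128/64 = 2 groups,
    # 64/16 = 4 sub-rows per group, so 8 pairs in total.
    var n = walk_k_subtiles[128, 64, 16]()
    print("total:", n)
```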
Returns:
TileTensor: A TileTensor view into the MMA-sized sub-tile.