Mojo function

load_matrix_a

load_matrix_a[m: Int, n: Int, k: Int](a_ptr: UnsafePointer[SIMD[float32, 1]], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[float32, 4]

Loads a tile of matrix A from memory to registers for TF32 tensor core operations.

Constraints:

The tile dimensions must be m=16, n=8, k=8.

Parameters:

  • m (Int): Number of rows in the output matrix tile.
  • n (Int): Number of columns in the output matrix tile.
  • k (Int): Inner dimension for matrix multiplication.

Args:

  • a_ptr (UnsafePointer[SIMD[float32, 1]]): Pointer to matrix A data in memory.
  • tile_row (Int): Starting row index of the tile.
  • tile_col (Int): Starting column index of the tile.
  • ldm (Int): Leading dimension of matrix A (stride between rows).

Returns:

SIMD vector containing 4 TF32 values loaded from matrix A, in the order required by the tensor core operation.
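
The role of `ldm` can be illustrated with plain index arithmetic. The following is a minimal Python sketch, not the library's implementation. It assumes matrix A is stored row-major and that `tile_row` and `tile_col` are element indices of the tile's top-left corner; the helper name `element_offset` is hypothetical.

```python
def element_offset(tile_row, tile_col, ldm, i=0, j=0):
    """Linear offset of element (i, j) of a tile whose top-left corner
    is at (tile_row, tile_col), assuming row-major storage (assumption)."""
    # ldm is the stride between consecutive rows, measured in elements.
    return (tile_row + i) * ldm + (tile_col + j)


# Example: a row-major matrix with 64 columns has ldm = 64.
ldm = 64
first = element_offset(tile_row=16, tile_col=8, ldm=ldm)             # top-left of the tile
last = element_offset(tile_row=16, tile_col=8, ldm=ldm, i=15, j=7)   # bottom-right of a 16x8 tile
print(first, last)
```

If A is a sub-view of a larger buffer, `ldm` is the row stride of the underlying buffer, which may be larger than the number of columns in the view.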

load_matrix_a[m: Int, n: Int, k: Int](a_ptr: UnsafePointer[SIMD[float16, 1]], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[float16, 4]

Loads a tile of matrix A from memory to registers for FP16 tensor core operations.

Constraints:

The tile dimensions must be m=16, n=8, k=8.

Parameters:

  • m (Int): Number of rows in the output matrix tile.
  • n (Int): Number of columns in the output matrix tile.
  • k (Int): Inner dimension for matrix multiplication.

Args:

  • a_ptr (UnsafePointer[SIMD[float16, 1]]): Pointer to matrix A data in memory.
  • tile_row (Int): Starting row index of the tile.
  • tile_col (Int): Starting column index of the tile.
  • ldm (Int): Leading dimension of matrix A (stride between rows).

Returns:

SIMD vector containing 4 FP16 values loaded from matrix A, in the order required by the tensor core operation.

load_matrix_a[m: Int, n: Int, k: Int](a_ptr: UnsafePointer[SIMD[bfloat16, 1]], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[bfloat16, (k // 2)]

Loads a tile of matrix A from memory to registers for BF16 tensor core operations.

Constraints:

The tile dimensions must be m=16, n=8, k=8 or m=16, n=8, k=16.

Parameters:

  • m (Int): Number of rows in the output matrix tile.
  • n (Int): Number of columns in the output matrix tile.
  • k (Int): Inner dimension for matrix multiplication.

Args:

  • a_ptr (UnsafePointer[SIMD[bfloat16, 1]]): Pointer to matrix A data in memory.
  • tile_row (Int): Starting row index of the tile.
  • tile_col (Int): Starting column index of the tile.
  • ldm (Int): Leading dimension of matrix A (stride between rows).

Returns:

SIMD vector containing k//2 BF16 values loaded from matrix A, in the order required by the tensor core operation.
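
The return widths of the three overloads are consistent with an m x k tile of A being split evenly across the threads of a warp. The Python sketch below checks that arithmetic. It assumes a warp of 32 threads and an even split, neither of which is stated on this page; `values_per_thread` is a hypothetical helper used only for illustration.

```python
WARP_SIZE = 32  # assumption: the tile is shared by one 32-thread warp


def values_per_thread(m, k):
    """Number of A-tile elements each thread holds if an m x k tile
    is divided evenly across the warp (assumption)."""
    total = m * k
    if total % WARP_SIZE != 0:
        raise ValueError("tile does not divide evenly across the warp")
    return total // WARP_SIZE


# Tile shapes taken from the constraints listed for each overload.
shapes = {
    "TF32 m16n8k8": (16, 8),
    "FP16 m16n8k8": (16, 8),
    "BF16 m16n8k8": (16, 8),
    "BF16 m16n8k16": (16, 16),
}
for name, (m, k) in shapes.items():
    # With m = 16, m * k / 32 reduces to k // 2.
    print(name, values_per_thread(m, k), k // 2)
```

For every supported shape the per-thread count equals `k // 2`, which matches the fixed width of 4 in the TF32 and FP16 signatures and the `(k // 2)` width in the BF16 signature.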