
Mojo function

load_matrix_b_amd

load_matrix_b_amd[m: Int, n: Int, k: Int](b_ptr: UnsafePointer[Float32, b_ptr.origin], tile_row: Int, tile_col: Int, ldm: Int) -> Float32

Loads a tile of matrix B from memory into registers for AMD FP32 tensor core operations. Each thread receives a single FP32 value.

Parameters:

  • m (Int): Number of rows in the output matrix tile.
  • n (Int): Number of columns in the output matrix tile.
  • k (Int): Inner dimension for matrix multiplication.

Args:

  • b_ptr (UnsafePointer): Pointer to matrix B data in memory.
  • tile_row (Int): Starting row index of the tile.
  • tile_col (Int): Starting column index of the tile.
  • ldm (Int): Leading dimension of matrix B (stride between rows).
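The role of ldm becomes clearer with the address arithmetic written out. This Python sketch assumes B is stored row-major with ldm counted in elements, so consecutive rows are ldm elements apart. The helper b_element_offset is hypothetical and is not part of the Mojo API.

```python
def b_element_offset(tile_row: int, tile_col: int, ldm: int, r: int, c: int) -> int:
    """Flat element offset of B[tile_row + r, tile_col + c].

    Hypothetical helper: assumes row-major storage with ldm measured in
    elements, so moving down one row advances the offset by ldm.
    """
    return (tile_row + r) * ldm + (tile_col + c)


# Tile starting at row 16, column 32 of a B matrix whose rows are 64 elements apart.
# Element (2, 3) of that tile sits at (16 + 2) * 64 + (32 + 3) = 1187.
print(b_element_offset(tile_row=16, tile_col=32, ldm=64, r=2, c=3))
```

Because ldm is passed separately from the tile position, the same call works for a tile inside a larger matrix, not just for a standalone tile.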

Returns:

Float32: A single FP32 value (a one-element SIMD vector) loaded from matrix B.

load_matrix_b_amd[dtype: DType, //, m: Int, n: Int, k: Int, n_blocks: Int = 1](b_ptr: UnsafePointer[Scalar[dtype], b_ptr.origin], tile_row: Int, tile_col: Int, ldm: Int, tile_loops: Int = 1) -> SIMD[dtype, 4]

Loads a tile of matrix B from memory to registers for AMD half-precision tensor core operations.

This function loads 4 consecutive values per thread from matrix B, arranged in the operand layout expected by AMD GPU tensor core operations. Which values a thread loads is determined by its lane position within the warp (called a wavefront on AMD GPUs).
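The Python sketch below models one plausible lane-to-element mapping for the 16x16x16 case. It assumes a 64-lane wavefront and a layout in which each lane takes one column of the k x n tile (lane % 16) and 4 consecutive rows along k, selected by lane // 16. This layout is an assumption modeled on AMD's matrix-instruction operand format, not read from the library source, and lane_b_coords is a hypothetical name.

```python
WAVEFRONT_SIZE = 64  # assumption: 64-lane wavefront
VALUES_PER_LANE = 4  # each thread returns SIMD[dtype, 4]


def lane_b_coords(lane: int, n: int = 16) -> list[tuple[int, int]]:
    """Hypothetical (row, col) positions, within a k x n tile of B, loaded by one lane.

    Assumed layout: the lane's column is lane % n, and lane // n selects
    which group of 4 consecutive rows (along k) the lane loads.
    """
    col = lane % n
    row_base = (lane // n) * VALUES_PER_LANE
    return [(row_base + i, col) for i in range(VALUES_PER_LANE)]


# Each lane's 4 values run down a single column of B (consecutive along k).
print(lane_b_coords(0))   # [(0, 0), (1, 0), (2, 0), (3, 0)]
print(lane_b_coords(17))  # [(4, 1), (5, 1), (6, 1), (7, 1)]

# Across the whole wavefront, every element of the 16 x 16 tile is loaded exactly once.
all_coords = [rc for lane in range(WAVEFRONT_SIZE) for rc in lane_b_coords(lane)]
assert sorted(all_coords) == [(r, c) for r in range(16) for c in range(16)]
```

The final assertion checks the property that matters: under this assumed layout, the 64 lanes together cover the tile with no gaps or overlaps.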

Constraints:

The tile configuration must be one of:

  • m=16, n=16, k=16, n_blocks=1
  • m=4, n=4, k=4, n_blocks=16
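As a consistency check on the two allowed shapes, this sketch compares the number of B elements each configuration consumes with the number of values a wavefront returns. It assumes a 64-lane wavefront; the helper b_elements_per_call is hypothetical. It illustrates a shared property of the two shapes and is not a statement of the library's design rationale.

```python
WAVEFRONT_SIZE = 64  # assumption: 64-lane wavefront
VALUES_PER_LANE = 4  # each thread returns SIMD[dtype, 4]

# The two allowed configurations: (m, n, k, n_blocks).
ALLOWED = [(16, 16, 16, 1), (4, 4, 4, 16)]


def b_elements_per_call(n: int, k: int, n_blocks: int) -> int:
    """Elements of B (a k x n tile per block) consumed across all blocks."""
    return k * n * n_blocks


loaded = WAVEFRONT_SIZE * VALUES_PER_LANE
for m, n, k, n_blocks in ALLOWED:
    needed = b_elements_per_call(n, k, n_blocks)
    print(f"m={m} n={n} k={k} n_blocks={n_blocks}: B elements={needed}, loaded={loaded}")
```

Both configurations need 256 elements of B (16 × 16 × 1 and 4 × 4 × 16), which equals 64 lanes × 4 values per lane.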

Parameters:

  • dtype (DType): Data type of the matrix elements (float16 or bfloat16).
  • m (Int): Number of rows in the output matrix tile.
  • n (Int): Number of columns in the output matrix tile.
  • k (Int): Inner dimension for matrix multiplication.
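  • n_blocks (Int): Number of blocks processed per call; must be 1 for 16x16x16 tiles and 16 for 4x4x4 tiles (see Constraints). Defaults to 1.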

Args:

  • b_ptr (UnsafePointer): Pointer to matrix B data in memory.
  • tile_row (Int): Starting row index of the tile.
  • tile_col (Int): Starting column index of the tile.
  • ldm (Int): Leading dimension of matrix B (stride between rows).
  • tile_loops (Int): Number of tile loops across matrix B's row dimension. Defaults to 1.

Returns:

SIMD[dtype, 4]: SIMD vector containing the 4 values this thread loaded from matrix B.
