Mojo function
gemm_kernel_amd
gemm_kernel_amd[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, c_layout_int_type: DType, a_layout_int_type: DType, b_layout_int_type: DType, c_linear_idx_type: DType, a_linear_idx_type: DType, b_linear_idx_type: DType, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: OptionalReg[fn[DType, Int, Int](IndexList[2], SIMD[$0, $1]) capturing -> None] = OptionalReg[fn[DType, Int, Int](IndexList[2], SIMD[$0, $1]) capturing -> None]({:i1 0, 1})](c: LayoutTensor[c_type, c_layout, MutableAnyOrigin, layout_int_type=c_layout_int_type, linear_idx_type=c_linear_idx_type], a: LayoutTensor[a_type, a_layout, MutableAnyOrigin, layout_int_type=a_layout_int_type, linear_idx_type=a_linear_idx_type], b: LayoutTensor[b_type, b_layout, MutableAnyOrigin, layout_int_type=b_layout_int_type, linear_idx_type=b_linear_idx_type])
AMD-optimized GEMM kernel for matrix multiplication C = A * B.
This kernel implements an efficient matrix multiplication algorithm optimized for AMD GPUs, with hierarchical tiling and structured memory access patterns.
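To make "hierarchical tiling" concrete, the following is a minimal host-side sketch in plain Python, not the kernel's actual implementation. The names `tiled_matmul`, `block_m`, `block_n`, and `block_k` are hypothetical and chosen for illustration; the real tile sizes come from `config`, and the real kernel distributes tiles across GPU thread blocks and warps rather than looping over them sequentially.

```python
def tiled_matmul(a, b, block_m=4, block_n=4, block_k=2):
    """Compute C = A * B with block tiling.

    a is an M x K list of lists, b is a K x N list of lists.
    block_m, block_n, block_k are hypothetical tile sizes.
    """
    m, k, n = len(a), len(a[0]), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    # Outer level: each (m0, n0) pair selects one output tile of C,
    # the unit of work a GPU thread block would own.
    for m0 in range(0, m, block_m):
        for n0 in range(0, n, block_n):
            # Middle level: march along K in block_k-wide slices, the
            # slices a GPU kernel would typically stage in fast memory.
            for k0 in range(0, k, block_k):
                # Inner level: accumulate the partial products of this slice.
                for i in range(m0, min(m0 + block_m, m)):
                    for j in range(n0, min(n0 + block_n, n)):
                        acc = c[i][j]
                        for kk in range(k0, min(k0 + block_k, k)):
                            acc += a[i][kk] * b[kk][j]
                        c[i][j] = acc
    return c
```

The result is independent of the tile sizes; tiling only changes the order in which partial products are accumulated and which data is reused while it is close to the compute units.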
Parameters:
- c_type (DType): Data type for the output matrix C.
- c_layout (Layout): Memory layout for matrix C.
- a_type (DType): Data type for the input matrix A.
- a_layout (Layout): Memory layout for matrix A.
- b_type (DType): Data type for the input matrix B.
- b_layout (Layout): Memory layout for matrix B.
- transpose_b (Bool): Whether matrix B should be transposed.
- c_layout_int_type (DType): Integer data type for the layout of matrix C (passed as layout_int_type of c).
- a_layout_int_type (DType): Integer data type for the layout of matrix A (passed as layout_int_type of a).
- b_layout_int_type (DType): Integer data type for the layout of matrix B (passed as layout_int_type of b).
- c_linear_idx_type (DType): Data type for the linear index of matrix C (passed as linear_idx_type of c).
- a_linear_idx_type (DType): Data type for the linear index of matrix A (passed as linear_idx_type of a).
- b_linear_idx_type (DType): Data type for the linear index of matrix B (passed as linear_idx_type of b).
- config (MatmulConfig[a_type, b_type, c_type, transpose_b]): GEMM configuration parameters (tile sizes, etc.).
- elementwise_lambda_fn (OptionalReg[fn[DType, Int, Int](IndexList[2], SIMD[$0, $1]) capturing -> None]): Optional function to apply to output elements.
Args:
- c (LayoutTensor[c_type, c_layout, MutableAnyOrigin, layout_int_type=c_layout_int_type, linear_idx_type=c_linear_idx_type]): Output matrix C (result).
- a (LayoutTensor[a_type, a_layout, MutableAnyOrigin, layout_int_type=a_layout_int_type, linear_idx_type=a_linear_idx_type]): Input matrix A.
- b (LayoutTensor[b_type, b_layout, MutableAnyOrigin, layout_int_type=b_layout_int_type, linear_idx_type=b_linear_idx_type]): Input matrix B (must be transposed).
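The roles of transpose_b and elementwise_lambda_fn can be illustrated with a small reference model in plain Python. This is a sketch under stated assumptions, not the kernel's code: it assumes that a transposed B is stored as an N x K matrix, and that when a lambda is supplied it receives each output coordinate and value in place of the direct store to C (the lambda's signature returns None, which suggests it is responsible for writing the result). The name `gemm_reference` is hypothetical, and the model passes one scalar per call where the real lambda receives a SIMD vector.

```python
def gemm_reference(c, a, b, transpose_b=True, elementwise_lambda_fn=None):
    """Reference semantics for C = A * B with an optional epilogue.

    a is M x K. b is N x K when transpose_b is True (assumed storage),
    otherwise K x N. c is an M x N list of lists that is written in place.
    """
    m, k = len(a), len(a[0])
    n = len(b) if transpose_b else len(b[0])
    for i in range(m):
        for j in range(n):
            acc = 0
            for kk in range(k):
                # With a transposed B, row j of b holds column j of the
                # logical B, so both operands are read along K.
                b_val = b[j][kk] if transpose_b else b[kk][j]
                acc += a[i][kk] * b_val
            if elementwise_lambda_fn is not None:
                # Assumption: the lambda takes over the store for (i, j).
                elementwise_lambda_fn((i, j), acc)
            else:
                c[i][j] = acc
```

A caller can use the lambda to fuse a cheap epilogue, such as adding a bias, into the matrix multiplication instead of running a second pass over C:

```python
a = [[1, 2, 3], [4, 5, 6]]
b_t = [[7, 9, 11], [8, 10, 12]]  # N x K storage of a 3 x 2 matrix B
out = [[0, 0], [0, 0]]

def add_one_and_store(idx, value):
    # Hypothetical epilogue: add a constant, then write the result.
    out[idx[0]][idx[1]] = value + 1

gemm_reference(out, a, b_t, elementwise_lambda_fn=add_one_and_store)
```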