Mojo function
test_matmul_sm90_swapAB_comparison_v2
test_matmul_sm90_swapAB_comparison_v2[a_type: DType, b_type: DType, c_type: DType, BM: Int, BN: Int, BK: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, num_pipeline_stages: UInt, num_consumer: UInt, k_group_size: UInt = 1, num_k_partitions: UInt = 1, partitioned_multicast: Bool = False, BM_SWAPAB: Int = BM, BN_SWAPAB: Int = BN, BK_SWAPAB: Int = BK, MMA_M_SWAPAB: Int = MMA_M, MMA_N_SWAPAB: Int = MMA_N, MMA_K_SWAPAB: Int = MMA_K, num_pipeline_stages_swapAB: UInt = num_pipeline_stages, num_consumer_swapAB: UInt = num_consumer, k_group_size_swapAB: UInt = k_group_size, num_k_partitions_swapAB: UInt = num_k_partitions, partitioned_multicast_swapAB: Bool = partitioned_multicast, use_vendor_reference: Bool = False, default_epilogue: Bool = False, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None](ctx: DeviceContext, m: ValOrDim[dim], n: ValOrDim[dim], k: ValOrDim[dim])
Compare matmul results between normal execution and swapAB execution.
This version accepts config parameters directly as compile-time values and builds configs internally.
Both compute: C[M,N] = A[M,K] @ B[N,K]^T.
swapAB internally swaps A and B and transposes C on store, but the result should match.
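The equivalence being tested rests on the transpose identity (A @ B^T)^T = B @ A^T. Below is a minimal pure-Python sketch of that identity. The helper names `matmul_abt` and `matmul_abt_swapped` are hypothetical and are not part of the Mojo kernel or test API; the real kernels run on the GPU with tiled MMA instructions, not nested loops.

```python
def matmul_abt(a, b):
    """Return C[M,N] = A[M,K] @ B[N,K]^T for row-major nested lists."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b]
            for row_a in a]


def matmul_abt_swapped(a, b):
    """Illustrative swapAB path: swap the operands, then transpose on store."""
    m, n = len(a), len(b)
    # With operands swapped, the product is C^T[N,M] = B[N,K] @ A[M,K]^T.
    c_t = matmul_abt(b, a)
    # "Transpose on store": element (j, i) of C^T is written to C[i][j].
    c = [[0] * n for _ in range(m)]
    for j in range(n):
        for i in range(m):
            c[i][j] = c_t[j][i]
    return c


a = [[1, 2, 3], [4, 5, 6]]                          # M=2, K=3
b = [[1, 0, 1], [0, 1, 0], [2, 2, 2], [1, 1, 1]]    # N=4, K=3

normal = matmul_abt(a, b)
swapped = matmul_abt_swapped(a, b)
print(normal == swapped)
```

With integer inputs the two results are identical. For floating-point types on real hardware, the two paths may accumulate in a different order, so a comparison with a tolerance is presumably more appropriate than exact equality.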
Parameters:
- a_type (DType): Data type of matrix A.
- b_type (DType): Data type of matrix B.
- c_type (DType): Data type of output matrix C.
- BM (Int): Block tile M dimension for normal kernel.
- BN (Int): Block tile N dimension for normal kernel.
- BK (Int): Block tile K dimension for normal kernel.
- MMA_M (Int): MMA M dimension for normal kernel.
- MMA_N (Int): MMA N dimension for normal kernel.
- MMA_K (Int): MMA K dimension for normal kernel.
- num_pipeline_stages (UInt): Number of pipeline stages for normal kernel.
- num_consumer (UInt): Number of consumers for normal kernel.
- k_group_size (UInt): K group size for normal kernel.
- num_k_partitions (UInt): Number of K partitions for normal kernel.
- partitioned_multicast (Bool): Partitioned multicast for normal kernel.
- BM_SWAPAB (Int): Block tile M dimension for swapAB kernel.
- BN_SWAPAB (Int): Block tile N dimension for swapAB kernel.
- BK_SWAPAB (Int): Block tile K dimension for swapAB kernel.
- MMA_M_SWAPAB (Int): MMA M dimension for swapAB kernel.
- MMA_N_SWAPAB (Int): MMA N dimension for swapAB kernel.
- MMA_K_SWAPAB (Int): MMA K dimension for swapAB kernel.
- num_pipeline_stages_swapAB (UInt): Number of pipeline stages for swapAB kernel.
- num_consumer_swapAB (UInt): Number of consumers for swapAB kernel.
- k_group_size_swapAB (UInt): K group size for swapAB kernel.
- num_k_partitions_swapAB (UInt): Number of K partitions for swapAB kernel.
- partitioned_multicast_swapAB (Bool): Partitioned multicast for swapAB kernel.
- use_vendor_reference (Bool): If True, use vendor matmul (cuBLAS) as reference instead of normal kernel.
- default_epilogue (Bool): If True, use default epilogue function that stores directly to output tensor.
- elementwise_compute_lambda_fn (OptionalReg): Optional compute lambda function to apply to each element before storing.
Args:
- ctx (DeviceContext): The device context.
- m (ValOrDim): The M dimension (can be static or dynamic).
- n (ValOrDim): The N dimension (can be static or dynamic).
- k (ValOrDim): The K dimension (can be static or dynamic).