
Mojo function

test_matmul_sm90_swapAB_comparison_v2

test_matmul_sm90_swapAB_comparison_v2[a_type: DType, b_type: DType, c_type: DType, BM: Int, BN: Int, BK: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, num_pipeline_stages: UInt, num_consumer: UInt, k_group_size: UInt = 1, num_k_partitions: UInt = 1, partitioned_multicast: Bool = False, BM_SWAPAB: Int = BM, BN_SWAPAB: Int = BN, BK_SWAPAB: Int = BK, MMA_M_SWAPAB: Int = MMA_M, MMA_N_SWAPAB: Int = MMA_N, MMA_K_SWAPAB: Int = MMA_K, num_pipeline_stages_swapAB: UInt = num_pipeline_stages, num_consumer_swapAB: UInt = num_consumer, k_group_size_swapAB: UInt = k_group_size, num_k_partitions_swapAB: UInt = num_k_partitions, partitioned_multicast_swapAB: Bool = partitioned_multicast, use_vendor_reference: Bool = False, default_epilogue: Bool = False, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None](ctx: DeviceContext, m: ValOrDim[dim], n: ValOrDim[dim], k: ValOrDim[dim])

Compare matmul results between normal execution and swapAB execution.

This version accepts config parameters directly as compile-time values and builds configs internally.

Both kernels compute C[M,N] = A[M,K] @ B[N,K]^T. The swapAB kernel internally swaps A and B and transposes C on store, so its result should match that of the normal kernel.
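The equivalence stated above follows from the transpose identity for matrix products. The short derivation below is an algebraic sketch only: the symbol D (the product the swapped kernel forms before the transposing store) is introduced here for illustration and does not appear in the source, and nothing is assumed about the kernel beyond what the description says.

```latex
% Normal kernel, with A of shape M x K and B of shape N x K:
C = A B^{\top}, \qquad C_{ij} = \sum_{k=1}^{K} A_{ik} B_{jk}

% swapAB kernel: the operands trade roles, giving an N x M product D
% (D is a hypothetical name for the pre-store result):
D = B A^{\top}, \qquad D_{ji} = \sum_{k=1}^{K} B_{jk} A_{ik}

% Transposing on store recovers the original output:
D^{\top} = \left( B A^{\top} \right)^{\top} = A B^{\top} = C
```

Each entry of C is a sum over the same K products in both paths, so the two outputs are mathematically identical. If the two kernels accumulate those products in a different order, small floating-point differences may still appear; whether the test compares exactly or within a tolerance is not stated in the source.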

Parameters:

  • a_type (DType): Data type of matrix A.
  • b_type (DType): Data type of matrix B.
  • c_type (DType): Data type of output matrix C.
  • BM (Int): Block tile M dimension for normal kernel.
  • BN (Int): Block tile N dimension for normal kernel.
  • BK (Int): Block tile K dimension for normal kernel.
  • MMA_M (Int): MMA M dimension for normal kernel.
  • MMA_N (Int): MMA N dimension for normal kernel.
  • MMA_K (Int): MMA K dimension for normal kernel.
  • num_pipeline_stages (UInt): Number of pipeline stages for normal kernel.
  • num_consumer (UInt): Number of consumers for normal kernel.
  • k_group_size (UInt): K group size for normal kernel.
  • num_k_partitions (UInt): Number of K partitions for normal kernel.
  • partitioned_multicast (Bool): Partitioned multicast for normal kernel.
  • BM_SWAPAB (Int): Block tile M dimension for swapAB kernel.
  • BN_SWAPAB (Int): Block tile N dimension for swapAB kernel.
  • BK_SWAPAB (Int): Block tile K dimension for swapAB kernel.
  • MMA_M_SWAPAB (Int): MMA M dimension for swapAB kernel.
  • MMA_N_SWAPAB (Int): MMA N dimension for swapAB kernel.
  • MMA_K_SWAPAB (Int): MMA K dimension for swapAB kernel.
  • num_pipeline_stages_swapAB (UInt): Number of pipeline stages for swapAB kernel.
  • num_consumer_swapAB (UInt): Number of consumers for swapAB kernel.
  • k_group_size_swapAB (UInt): K group size for swapAB kernel.
  • num_k_partitions_swapAB (UInt): Number of K partitions for swapAB kernel.
  • partitioned_multicast_swapAB (Bool): Partitioned multicast for swapAB kernel.
  • use_vendor_reference (Bool): If True, use vendor matmul (cuBLAS) as reference instead of normal kernel.
  • default_epilogue (Bool): If True, use default epilogue function that stores directly to output tensor.
  • elementwise_compute_lambda_fn (OptionalReg): Optional compute lambda function to apply to each element before storing.

Args:

  • ctx (DeviceContext): The device context.
  • m (ValOrDim): The M dimension (can be static or dynamic).
  • n (ValOrDim): The N dimension (can be static or dynamic).
  • k (ValOrDim): The K dimension (can be static or dynamic).
