Skip to main content

Mojo function

gemm_mma_cpasync_kernel

gemm_mma_cpasync_kernel[c_type: DType, a_type: DType, b_type: DType, c_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, *, tile_m: Int = 16, tile_n: Int = 8, tile_k: Int = 128, stage_cnt: Int = 2, accum_type: DType = DType.float32, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, pdl_level: PDLLevel = PDLLevel()](output: TileTensor[c_type, c_layout, MutAnyOrigin], act: TileTensor[a_type, a_layout, ImmutAnyOrigin], weight: TileTensor[b_type, b_layout, ImmutAnyOrigin], gemm_m: Int, gemm_k: Int, gemm_n: Int, batch_size: Int)