Mojo function

gemm_mma_cpasync

gemm_mma_cpasync[pdl_level: PDLLevel = PDLLevel(), tile_k: Int = 128, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: TileTensor[address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], act: TileTensor[address_space=act.address_space, linear_idx_type=act.linear_idx_type, element_size=act.element_size], weight: TileTensor[address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size], gemm_m: Int, gemm_k: Int, gemm_n: Int, batch_size: Int, ctx: DeviceContext)

Launch the batched GEMM tensor-core kernel.

C[gemm_m, gemm_n] = act[gemm_m, gemm_k] x weight[gemm_n, gemm_k]^T, where gemm_k is the shared reduction dimension.
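
To illustrate the formula above, here is a minimal CPU reference sketch in plain Python. It is not the kernel's implementation. It shows only the arithmetic for a single batch entry, assuming both operands are stored row-major with the reduction dimension (`gemm_k`) last. The name `gemm_reference` is hypothetical, and batching, `tile_k`, and the elementwise epilogue are not modeled.

```python
def gemm_reference(act, weight):
    """Return C = act x weight^T for one batch entry (illustrative only).

    act:    gemm_m rows, each holding gemm_k values
    weight: gemm_n rows, each holding gemm_k values
    result: gemm_m rows, each holding gemm_n values
    """
    gemm_k = len(act[0])
    # Both operands must share the reduction dimension gemm_k.
    assert all(len(row) == gemm_k for row in act)
    assert all(len(row) == gemm_k for row in weight)

    # C[m][n] is the dot product of act row m with weight row n,
    # which is what multiplying by the transpose of weight means.
    return [
        [sum(a * w for a, w in zip(act_row, weight_row)) for weight_row in weight]
        for act_row in act
    ]


# gemm_m = 2, gemm_k = 3, gemm_n = 4
act = [[1, 2, 3], [4, 5, 6]]
weight = [[1, 0, 0], [0, 1, 0], [1, 1, 1], [2, 0, 1]]
c = gemm_reference(act, weight)
```

Because `weight` is indexed as `[gemm_n, gemm_k]`, each output element reads one contiguous row from each operand, so no explicit transpose is materialized.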

Args: